+ - Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution
- ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
- AIGC: Acceleration of Stable Diffusion
- Biomedicine: Acceleration of AlphaFold Protein Structure
@@ -127,6 +127,36 @@ distributed training and inference in a few lines.
## Colossal-AI in the Real World
+### Colossal-LLaMA-2
+
+- One half-day of training for a few hundred dollars yields results similar to mainstream large models: an open-source, commercial-free, domain-specific LLM solution.
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
+[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base)
+
+| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot) | AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) |
+| :----------------------------: | :--------: | :-------------: | :-----------: | :------------: | :--------------: | :-------------: | :------------: |
+| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| | | | | | | | |
+| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.20 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.20 |
+
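+A minimal usage sketch with the standard Hugging Face `transformers` API (the prompt and generation arguments here are illustrative, not tuned recommendations):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "hpcai-tech/Colossal-LLaMA-2-7b-base"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda()
+
+inputs = tokenizer("Introduce some landmarks in Beijing:", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_k=50, top_p=0.95)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+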
### ColossalChat
@@ -224,7 +254,7 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
- 70-billion-parameter LLaMA2 model training accelerated by 195%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
### LLaMA1
@@ -472,7 +502,7 @@ To cite this project, you can use the following BibTeX citation.
}
```
-Colossal-AI has been accepted as official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/applications/Chat/README.md b/applications/Chat/README.md
index 5a1187ab503d..d5be04ab9f44 100644
--- a/applications/Chat/README.md
+++ b/applications/Chat/README.md
@@ -200,7 +200,6 @@ We provide an online inference server and a benchmark. We aim to run inference o
We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference.
Online inference server scripts can help you deploy your own services.
-
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
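+
+As a rough sketch of the idea behind 8-bit RTN (round-to-nearest) quantization, here is a minimal per-channel version in plain PyTorch; this illustrates the technique only and is not the kernels used by the inference server:
+
+```python
+import torch
+
+def rtn_quantize_int8(w: torch.Tensor):
+    # Symmetric per-output-channel round-to-nearest quantization.
+    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
+    q = (w / scale).round().clamp(-127, 127).to(torch.int8)
+    return q, scale
+
+def rtn_dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    return q.float() * scale
+
+w = torch.randn(8, 16)
+q, scale = rtn_quantize_int8(w)
+# The reconstruction error is bounded by half a quantization step per channel.
+print((rtn_dequantize(q, scale) - w).abs().max())
+```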
## Coati7B examples
@@ -414,7 +413,7 @@ You may contact us or participate in the following ways:
1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your support. Thanks!
2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) or submitting a PR on GitHub, following the guidelines in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
3. Joining the Colossal-AI community on
- [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
+ [Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack),
and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
4. Sending your official proposal by email to contact@hpcaitech.com
@@ -428,7 +427,7 @@ Thanks so much to all of our amazing contributors!
-- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
+- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
@@ -469,8 +468,7 @@ Coati is developed by ColossalAI Team:
- [ofey404](https://github.com/ofey404)
- [Wenhao Chen](https://github.com/CWHer)
-The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
-
+The PhD students from the [HPC-AI Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
- [Xue Fuzhao](https://github.com/XueFuzhao)
diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index 90471ed727b0..0d0e2a7d34f5 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
def preprocess_batch(samples) -> dict:
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
- return {'input_ids': input_ids, 'attention_mask': attention_mask}
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
def print_rank_0(*args, **kwargs) -> None:
@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
B = 1024**3
M = 1024**2
K = 1024
- outputs = ''
+ outputs = ""
for name, numel in model_dict.items():
- outputs += f'{name}: '
+ outputs += f"{name}: "
if numel >= B:
- outputs += f'{numel / B:.2f} B\n'
+ outputs += f"{numel / B:.2f} B\n"
elif numel >= M:
- outputs += f'{numel / M:.2f} M\n'
+ outputs += f"{numel / M:.2f} M\n"
elif numel >= K:
- outputs += f'{numel / K:.2f} K\n'
+ outputs += f"{numel / K:.2f} K\n"
else:
- outputs += f'{numel}\n'
+ outputs += f"{numel}\n"
print_rank_0(outputs)
def get_gpt_config(model_name: str) -> OPTConfig:
model_map = {
- '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
- '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
- '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
- '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
- '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
- '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
- '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
- '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
- '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
- '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
+ "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
+ "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+ "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+ "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
+ "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
+ "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+ "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+ "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
+ "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+ "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
}
try:
return model_map[model_name]
@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
def main(args):
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_gemini_cpu':
- strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
- elif args.strategy == 'colossalai_zero1':
- strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero1_cpu':
- strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
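+ # placement_policy="static" with offload fractions supersedes the old "cuda"/"cpu"
+ # placement policies (parameter naming here follows ColossalAI's Gemini API)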
+ strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
+ elif args.strategy == "colossalai_gemini_cpu":
+ strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
+ elif args.strategy == "colossalai_zero1":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero1_cpu":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@@ -103,90 +103,106 @@ def main(args):
if args.use_kernels:
from coati.kernels import convert_to_xformer_model
- actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
- (actor, critic, initial_model, reward_model))
+
+ actor, critic, initial_model, reward_model = map(
+ convert_to_xformer_model, (actor, critic, initial_model, reward_model)
+ )
actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy)
initial_model_numel = get_model_numel(initial_model, strategy)
reward_model_numel = get_model_numel(reward_model, strategy)
- print_model_numel({
- 'Actor': actor_numel,
- 'Critic': critic_numel,
- 'Initial model': initial_model_numel,
- 'Reward model': reward_model_numel
- })
- performance_evaluator = PerformanceEvaluator(actor_numel,
- critic_numel,
- initial_model_numel,
- reward_model_numel,
- enable_grad_checkpoint=False,
- ignore_episodes=1)
-
- if args.strategy.startswith('colossalai'):
+ print_model_numel(
+ {
+ "Actor": actor_numel,
+ "Critic": critic_numel,
+ "Initial model": initial_model_numel,
+ "Reward model": reward_model_numel,
+ }
+ )
+ performance_evaluator = PerformanceEvaluator(
+ actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=1,
+ )
+
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
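+ # decoder-only generation uses left padding so that real tokens stay right-aligned in a batch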
+ tokenizer.padding_side = "left"
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
- dataloader = DataLoader(random_prompts,
- batch_size=args.experience_batch_size,
- shuffle=True,
- collate_fn=preprocess_batch)
-
- trainer = PPOTrainer(strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- ptx_coef=0,
- train_batch_size=args.train_batch_size,
- offload_inference_models=args.offload_inference_models,
- max_length=512,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- use_cache=True,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- callbacks=[performance_evaluator])
-
- trainer.fit(prompt_dataloader=dataloader,
- pretrain_dataloader=None,
- num_episodes=args.num_episodes,
- num_update_steps=args.num_update_steps,
- num_collect_steps=args.num_collect_steps)
-
- print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
-
-
-if __name__ == '__main__':
+ dataloader = DataLoader(
+ random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
+ )
+
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ tokenizer=tokenizer,
+ ptx_coef=0,
+ train_batch_size=args.train_batch_size,
+ offload_inference_models=args.offload_inference_models,
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ use_cache=True,
+ callbacks=[performance_evaluator],
+ )
+
+ trainer.fit(
+ prompt_dataloader=dataloader,
+ pretrain_dataloader=None,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
+
+ print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
+
+
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='125m')
- parser.add_argument('--critic_model', default='125m')
- parser.add_argument('--strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
- 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
- ],
- default='ddp')
- parser.add_argument('--num_episodes', type=int, default=3)
- parser.add_argument('--num_collect_steps', type=int, default=8)
- parser.add_argument('--num_update_steps', type=int, default=1)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0)
- parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
- parser.add_argument('--offload_inference_models', action='store_true', default=False)
- parser.add_argument('--use_kernels', action='store_true', default=False)
+ parser.add_argument("--model", default="125m")
+ parser.add_argument("--critic_model", default="125m")
+ parser.add_argument(
+ "--strategy",
+ choices=[
+ "ddp",
+ "colossalai_gemini",
+ "colossalai_gemini_cpu",
+ "colossalai_zero2",
+ "colossalai_zero2_cpu",
+ "colossalai_zero1",
+ "colossalai_zero1_cpu",
+ ],
+ default="ddp",
+ )
+ parser.add_argument("--num_episodes", type=int, default=3)
+ parser.add_argument("--num_collect_steps", type=int, default=8)
+ parser.add_argument("--num_update_steps", type=int, default=1)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0)
+ parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
+ parser.add_argument("--offload_inference_models", action="store_true", default=False)
+ parser.add_argument("--use_kernels", action="store_true", default=False)
args = parser.parse_args()
main(args)
diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py
index 7fc990448805..98ace3869450 100644
--- a/applications/Chat/benchmarks/ray/1mmt_dummy.py
+++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py
@@ -22,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -36,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
- 'local_rank': '0',
- 'rank': '0',
- 'world_size': '1',
- 'master_port': maker_port,
- 'master_addr': master_addr
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
}
# configure tokenizer
@@ -63,21 +66,27 @@ def model_fn():
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
- reward_model = get_reward_model_from_args(args.critic_model,
- config=critic_cfg).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -97,15 +106,18 @@ def model_fn():
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
- critic = get_critic_from_args(args.critic_model,
- config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
- f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
+ f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
@@ -114,7 +126,8 @@ def trainer_model_fn():
buffer_limit=16,
eval_performance=True,
debug=args.debug,
- ) for i, env_info_trainer in enumerate(env_info_trainers)
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
@@ -122,7 +135,7 @@ def trainer_model_fn():
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
- return {'input_ids': input_ids, 'attention_mask': attn_mask}
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@@ -138,8 +151,10 @@ def build_dataloader(size):
wait_tasks = []
wait_tasks.append(
- experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
- num_steps=args.experience_steps))
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
@@ -148,31 +163,30 @@ def build_dataloader(size):
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py
index ca1df22070fc..f8860f2979ee 100644
--- a/applications/Chat/benchmarks/ray/mmmt_dummy.py
+++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py
@@ -22,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -36,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
- env_info_makers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_makers),
- 'master_port': maker_port,
- 'master_addr': master_addr
- } for rank in range(args.num_makers)]
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,14 +69,20 @@ def model_fn():
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
- reward_model = get_reward_model_from_args(args.critic_model,
- config=critic_cfg).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -79,7 +91,7 @@ def model_fn():
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
- f'trainer{x}'
+ f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -103,8 +115,11 @@ def model_fn():
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
- critic = get_critic_from_args(args.critic_model,
- config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
return actor, critic
# configure Trainer
@@ -130,7 +145,7 @@ def trainer_model_fn():
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
- return {'input_ids': input_ids, 'attention_mask': attn_mask}
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@@ -147,43 +162,48 @@ def build_dataloader(size):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
- experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
- num_steps=args.experience_steps))
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
- total_steps = args.experience_batch_size * args.experience_steps * \
- args.num_makers // (args.num_trainers * args.train_batch_size)
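+ # makers produce num_makers * experience_steps * experience_batch_size samples in total,
+ # and each of the num_trainers trainers consumes train_batch_size samples per step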
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--num_makers', type=int, default=1)
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py
index bd4e5460d11e..599b57609775 100644
--- a/applications/Chat/coati/dataset/__init__.py
+++ b/applications/Chat/coati/dataset/__init__.py
@@ -4,7 +4,10 @@
from .utils import is_rank_0
__all__ = [
- 'RmStaticDataset', 'HhRlhfDataset',
- 'SFTDataset', 'SupervisedDataset',
- 'PromptDataset', 'is_rank_0',
+ "RmStaticDataset",
+ "HhRlhfDataset",
+ "SFTDataset",
+ "SupervisedDataset",
+ "PromptDataset",
+ "is_rank_0",
]
diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py
index 465fa867c7ab..f2180d96b0d3 100644
--- a/applications/Chat/coati/dataset/conversation.py
+++ b/applications/Chat/coati/dataset/conversation.py
@@ -49,7 +49,7 @@ def append_message(self, role, message):
def to_gradio_chatbot(self):
ret = []
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
@@ -57,12 +57,14 @@ def to_gradio_chatbot(self):
return ret
def copy(self):
- return Conversation(system=self.system,
- roles=self.roles,
- messages=[[x, y] for x, y in self.messages],
- offset=self.offset,
- sep_style=self.sep_style,
- sep=self.sep)
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
def dict(self):
return {
@@ -70,7 +72,7 @@ def dict(self):
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
- "sep": self.sep
+ "sep": self.sep,
}
diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py
index 2c953fffa513..17120e6064b5 100644
--- a/applications/Chat/coati/dataset/prompt_dataset.py
+++ b/applications/Chat/coati/dataset/prompt_dataset.py
@@ -13,11 +13,13 @@
class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self,
- data_path: str,
- tokenizer: transformers.PreTrainedTokenizer,
- max_datasets_size: int = None,
- max_length: int = 96):
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ max_datasets_size: int = None,
+ max_length: int = 96,
+ ):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
self.logger = get_dist_logger()
@@ -30,11 +32,9 @@ def __init__(self,
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
- tokens = tokenizer(instructions,
- return_tensors='pt',
- max_length=max_length,
- padding='max_length',
- truncation=True)
+ tokens = tokenizer(
+ instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
+ )
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py
index 3c4ec8b214bb..3afcd7b69238 100644
--- a/applications/Chat/coati/dataset/reward_dataset.py
+++ b/applications/Chat/coati/dataset/reward_dataset.py
@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.end_token = tokenizer.eos_token \
- if special_token is None else special_token
-
- chosen = [
- data["prompt"] + data["chosen"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen = {
- "input_ids": chosen_token["input_ids"],
- "attention_mask": chosen_token["attention_mask"]
- }
-
- reject = [
- data["prompt"] + data["rejected"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject = {
- "input_ids": reject_token["input_ids"],
- "attention_mask": reject_token["attention_mask"]
- }
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
- self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
# Anthropic/hh-rlhf
@@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.end_token = tokenizer.eos_token \
- if special_token is None else special_token
-
- chosen = [
- data["chosen"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen = {
- "input_ids": chosen_token["input_ids"],
- "attention_mask": chosen_token["attention_mask"]
- }
-
- reject = [
- data["rejected"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject = {
- "input_ids": reject_token["input_ids"],
- "attention_mask": reject_token["attention_mask"]
- }
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
- self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 2959d3fac81c..c0e257f54a07 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -13,13 +13,14 @@
# limitations under the License.
import copy
-from typing import Dict, Sequence, Tuple
+from typing import Dict, Optional, Sequence, Tuple
import torch
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
@@ -28,33 +29,35 @@
IGNORE_INDEX = -100
PROMPT_DICT = {
- "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
- "prompt_no_input": ("Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"),
+ "prompt_input": (
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+ ),
+ "prompt_no_input": (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+ ),
}
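+# e.g. PROMPT_DICT["prompt_no_input"].format_map({"instruction": "..."}) yields an Alpaca-style
+# prompt ending in "### Response:", after which the target completion is appended during preprocessing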
-def _preprocess(sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: PreTrainedTokenizer,
- max_length: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
- sequences_token = tokenizer(sequences,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- sources_token = tokenizer(sources,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
-
+ sequences_token = tokenizer(
+ sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ sources_token = tokenizer(
+ sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+
+ assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
@@ -62,25 +65,27 @@ def _preprocess(sources: Sequence[str],
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
+ if pad_len > 0:  # guard: when pad_len == 0, labels[i][-0:] would mask the whole sequence
+     labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
- labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
+ labels[i][: pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
-def _preprocess_chatglm(sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: PreTrainedTokenizer,
- max_length: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess_chatglm(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preprocess the data by tokenizing.
Returns None for the attention mask; ChatGLM computes it from the input ids.
"""
-
+
labels = []
input_ids = []
for source, target in zip(sources, targets):
@@ -90,16 +95,16 @@ def _preprocess_chatglm(sources: Sequence[str],
# truncate
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
truncate_length = max(0, len(input_id) - max_length)
- input_id = input_id[truncate_length: ]
+ input_id = input_id[truncate_length:]
if truncate_length == len(source_id) + 1:
- input_id = sp_token_list + input_id[1: ]
+ input_id = sp_token_list + input_id[1:]
elif truncate_length > len(source_id) + 1:
- input_id = sp_token_list + input_id[2: ]
-
+ input_id = sp_token_list + input_id[2:]
+
context_length = input_id.index(tokenizer.bos_token_id)
mask_position = context_length - 1
- label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
-
+ label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
+
pad_len = max_length - len(input_id)
input_id = input_id + [tokenizer.pad_token_id] * pad_len
input_ids.append(input_id)
@@ -117,25 +122,22 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
- def __init__(self,
- dataset: Dict,
- tokenizer: PreTrainedTokenizer,
- max_length: int = 512
- ) -> None:
+ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
super().__init__()
self.input_ids = []
sources = [data["prompt"] for data in dataset]
- targets = [
- data["completion"] + tokenizer.eos_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
+ targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
+
+ logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
else:
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
+
+ logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
@@ -143,22 +145,21 @@ def __len__(self):
def __getitem__(self, idx):
if self.attention_mask is not None:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self,
- data_path: str,
- tokenizer: PreTrainedTokenizer,
- max_datasets_size: int = None,
- max_length: int = 512):
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: PreTrainedTokenizer,
+ max_datasets_size: Optional[int] = None,
+ max_length: int = 512,
+ ):
super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
@@ -174,18 +175,17 @@ def __init__(self,
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
- targets = [
- example['output'] + tokenizer.eos_token
- for example in list_data_dict
- ]
+ targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
else:
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
+
+ logger.info("Loaded dataset.")
def __len__(self):
length = self.input_ids.shape[0]
@@ -193,9 +193,6 @@ def __len__(self):
def __getitem__(self, idx):
if self.attention_mask is not None:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py
index c0188dc4a471..f2a48d0a3b20 100644
--- a/applications/Chat/coati/experience_buffer/__init__.py
+++ b/applications/Chat/coati/experience_buffer/__init__.py
@@ -1,4 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
-__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
+__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py
index 9ccdc935d506..7047785308f3 100644
--- a/applications/Chat/coati/experience_buffer/base.py
+++ b/applications/Chat/coati/experience_buffer/base.py
@@ -7,9 +7,9 @@
class ExperienceBuffer(ABC):
"""Experience buffer base class. It stores experience.
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py
index bd5213b38993..d47b67dbe713 100644
--- a/applications/Chat/coati/experience_buffer/naive.py
+++ b/applications/Chat/coati/experience_buffer/naive.py
@@ -1,4 +1,5 @@
import random
+import warnings
from typing import List
import torch
@@ -11,28 +12,30 @@
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
- cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
- self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
+ self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
- experience.to_device(torch.device('cpu'))
+ experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)
+
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
+ warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py
index c2a34212e2f4..baedbebd184f 100644
--- a/applications/Chat/coati/experience_buffer/utils.py
+++ b/applications/Chat/coati/experience_buffer/utils.py
@@ -21,6 +21,7 @@ class BufferItem:
"A" is the number of actions.
"""
+
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@@ -33,8 +34,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
- keys = ('sequences', 'action_log_probs', 'values',
- 'reward', 'advantages', 'attention_mask', 'action_mask')
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
-def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
- assert side in ('left', 'right')
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
+ assert side in ("left", "right")
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
- padding = (pad_len, 0) if side == 'left' else (0, pad_len)
+ padding = (pad_len, 0) if side == "left" else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
- to_pad_keys = set(('action_log_probs', 'action_mask'))
- keys = ('sequences', 'action_log_probs', 'values',
- 'reward', 'advantages', 'attention_mask', 'action_mask')
+ to_pad_keys = set(("action_log_probs", "action_mask"))
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py
index 39ca7576b227..06452292e77c 100644
--- a/applications/Chat/coati/experience_maker/__init__.py
+++ b/applications/Chat/coati/experience_maker/__init__.py
@@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
-__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
+__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index ff75852576c8..0731f6e0f97f 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -3,14 +3,13 @@
from typing import Optional
import torch
-import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import Actor, Critic, RewardModel
@dataclass
class Experience:
"""Experience is a batch of data.
- These data should have the the sequence length and number of actions.
+ These data should have the sequence length and number of actions.
Left padding for sequences is applied.
Shapes of each tensor:
@@ -24,6 +23,7 @@ class Experience:
"A" is the number of actions.
"""
+
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@@ -58,20 +58,13 @@ def pin_memory(self):
class ExperienceMaker(ABC):
-
- def __init__(self,
- actor: Actor,
- critic: nn.Module,
- reward_model: nn.Module,
- initial_model: Actor,
- kl_coef: float = 0.1) -> None:
+ def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
- self.kl_coef = kl_coef
@abstractmethod
- def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
+ def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass
diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py
index 496f8ab445fc..941e1994b148 100644
--- a/applications/Chat/coati/experience_maker/naive.py
+++ b/applications/Chat/coati/experience_maker/naive.py
@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
+from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
+from transformers import PreTrainedTokenizer
from .base import Experience, ExperienceMaker
@@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker.
"""
+ def __init__(
+ self,
+ actor: Actor,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: Actor,
+ tokenizer: PreTrainedTokenizer,
+ kl_coef: float = 0.1,
+ ) -> None:
+ super().__init__(actor, critic, reward_model, initial_model)
+ self.tokenizer = tokenizer
+ self.kl_coef = kl_coef
+
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()
@@ -19,33 +34,32 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie
self.reward_model.eval()
# generate sequences
- sequences = generate(self.actor, input_ids, **generate_kwargs)
+ sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# calculate auxiliary tensors
attention_mask = None
- pad_token_id = generate_kwargs.get('pad_token_id', None)
+ pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is not None:
- attention_mask = sequences.not_equal(pad_token_id)\
- .to(dtype=torch.long, device=sequences.device)
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
- eos_token_id = generate_kwargs.get('eos_token_id', None)
+ eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
- action_mask = action_mask[:, -(sequences.size(1) - input_len):]
+ action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
- actor_output = self.actor(sequences, attention_mask)
+ actor_output = self.actor(sequences, attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
- base_model_output = self.initial_model(sequences, attention_mask)
+ base_model_output = self.initial_model(sequences, attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
- value = self.critic(sequences, action_mask, attention_mask)
+ value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
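`compute_reward` itself is imported from `coati.models.utils` and not shown in this diff. Below is a plausible minimal sketch consistent with the call site above: the reward-model score minus a KL penalty against the frozen initial model, averaged over unmasked action positions. Treat it as an assumption, not the repo's actual code.

```python
import torch


def approx_kl(log_probs: torch.Tensor, ref_log_probs: torch.Tensor,
              action_mask: torch.Tensor) -> torch.Tensor:
    # Per-sequence mean of (log pi - log pi_ref) over valid action positions.
    log_ratio = (log_probs - ref_log_probs) * action_mask
    return log_ratio.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)


def compute_reward(r: torch.Tensor, kl_coef: float,
                   log_probs: torch.Tensor, ref_log_probs: torch.Tensor,
                   action_mask: torch.Tensor) -> torch.Tensor:
    # Reward-model score penalized by divergence from the frozen initial model.
    if kl_coef <= 0.0:
        return r
    kl = approx_kl(log_probs, ref_log_probs, action_mask)
    return r - kl_coef * kl
```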
diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py
index 230eedf7ecba..96d40c7c4709 100644
--- a/applications/Chat/coati/kernels/__init__.py
+++ b/applications/Chat/coati/kernels/__init__.py
@@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model
__all__ = [
- 'convert_to_xformer_model',
- 'recover_from_xformer_model',
+ "convert_to_xformer_model",
+ "recover_from_xformer_model",
]
diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py
index e99f9c2247d1..d1eb139187f3 100644
--- a/applications/Chat/coati/kernels/opt_attn.py
+++ b/applications/Chat/coati/kernels/opt_attn.py
@@ -21,11 +21,12 @@ def forward(
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
if not self.training:
- return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
- output_attentions)
+ return super().forward(
+ hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
+ )
"""Input shape: Batch x Time x Channel"""
- assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
- assert not output_attentions, 'Xformers attention does not support output_attentions'
+ assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
+ assert not output_attentions, "Xformers attention does not support output_attentions"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
@@ -69,12 +70,14 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- attn_output = xops.memory_efficient_attention(query_states,
- key_states,
- value_states,
- attn_bias=xops.LowerTriangularMask(),
- p=self.dropout if self.training else 0.0,
- scale=self.scaling)
+ attn_output = xops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xops.LowerTriangularMask(),
+ p=self.dropout if self.training else 0.0,
+ scale=self.scaling,
+ )
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
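A short usage sketch of the fused xformers kernel called above, with tensor layout `[batch, seq, heads, head_dim]`; it assumes the optional `xformers` package is installed and a CUDA device is available:

```python
import torch
import xformers.ops as xops

B, S, H, D = 2, 16, 8, 64
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)

# Fused causal attention; p is the dropout probability,
# scale defaults to 1/sqrt(head_dim) when left unset.
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    p=0.0,
)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```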
diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py
index 0a296a863756..ad4a525b4af2 100644
--- a/applications/Chat/coati/models/__init__.py
+++ b/applications/Chat/coati/models/__init__.py
@@ -3,6 +3,13 @@
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
- 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
- 'LoRAModule', 'convert_to_lora_module'
+ "Actor",
+ "Critic",
+ "RewardModel",
+ "PolicyLoss",
+ "ValueLoss",
+ "LogSigLoss",
+ "LogExpLoss",
+ "LoRAModule",
+ "convert_to_lora_module",
]
diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py
index c5f748a0c85a..5c9905bb2224 100644
--- a/applications/Chat/coati/models/base/__init__.py
+++ b/applications/Chat/coati/models/base/__init__.py
@@ -9,7 +9,7 @@
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
- For Actor, Critic and RewardModel, return ``model.model``,
+ For Actor, Critic and RewardModel, return ``model.model``,
it's usually a ``transformers.PreTrainedModel``.
Args:
@@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
Returns:
nn.Module: the base model
"""
- assert isinstance(model, (Actor, Critic, RewardModel)), \
- f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+ assert isinstance(
+ model, (Actor, Critic, RewardModel)
+ ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
return model.model
-__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
+__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
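Usage of the unwrapping helper above, assuming (as with the BLOOM classes later in this diff) that constructing an actor with no arguments falls back to a default backbone config:

```python
from coati.models.base import get_base_model
from coati.models.gpt import GPTActor

actor = GPTActor()            # wraps a default GPT-2 backbone (assumed default-constructible)
base = get_base_model(actor)  # -> the underlying transformers.PreTrainedModel
print(type(base).__name__)    # e.g. "GPT2LMHeadModel"
```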
diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py
index 6842f81d9b87..8b2b81ed071c 100644
--- a/applications/Chat/coati/models/base/actor.py
+++ b/applications/Chat/coati/models/base/actor.py
@@ -16,18 +16,18 @@ class Actor(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
def forward(
- self,
- input_ids: torch.LongTensor,
- attention_mask: Optional[torch.Tensor] = None,
- **model_kwargs, # HACK: `generate` method may pass more kwargs
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **model_kwargs,
) -> torch.Tensor:
- """Returns model output.
- """
+ """Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output
+
diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py
index e68a743a7762..8672365f5783 100644
--- a/applications/Chat/coati/models/base/critic.py
+++ b/applications/Chat/coati/models/base/critic.py
@@ -1,10 +1,7 @@
-from typing import Optional
-
import torch
import torch.nn as nn
from ..lora import LoRAModule
-from ..utils import masked_mean
class Critic(LoRAModule):
@@ -19,36 +16,19 @@ class Critic(LoRAModule):
"""
def __init__(
- self,
- model: nn.Module,
- value_head: nn.Module,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- use_action_mask: bool = False,
+ self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
) -> None:
-
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
- self.use_action_mask = use_action_mask
self.convert_to_lora()
- def forward(self,
- sequences: torch.LongTensor,
- action_mask: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
-
- values = self.value_head(last_hidden_states).squeeze(-1)
-
- if action_mask is not None and self.use_action_mask:
- num_actions = action_mask.size(1)
- prompt_mask = attention_mask[:, :-num_actions]
- values = values[:, :-num_actions]
- value = masked_mean(values, prompt_mask, dim=1)
- return value
-
- values = values[:, :-1]
- value = values.mean(dim=1)
- return value
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+ 0
+ ]
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+ values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
+ return values
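The rewritten `Critic.forward` pools a single hidden state per sequence: the position of the last attended token, located with `torch.max(attention_mask * arange(S))`. The indexing trick in isolation (mask values invented):

```python
import torch

# Two right-padded sequences: 1 marks real tokens, 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
S = attention_mask.size(1)

# Max of (mask * position index) is the index of the last position with mask == 1.
last_idx = torch.max(attention_mask * torch.arange(S), dim=1)[0]
print(last_idx)  # tensor([2, 4])

hidden = torch.randn(2, S, 8)               # (B, S, hidden_size)
pooled = hidden[torch.arange(2), last_idx]  # (B, hidden_size)
print(pooled.shape)  # torch.Size([2, 8])
```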
diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py
index ce8c0a1d3568..e9545d1cddaf 100644
--- a/applications/Chat/coati/models/base/reward_model.py
+++ b/applications/Chat/coati/models/base/reward_model.py
@@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- model: nn.Module,
- value_head: Optional[nn.Module] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ model: nn.Module,
+ value_head: Optional[nn.Module] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
@@ -33,9 +35,12 @@ def __init__(self,
else:
self.value_head = nn.Linear(model.config.n_embd, 1)
- def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
- values = self.value_head(last_hidden_states)[:, :-1]
- value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
- return value
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+ 0
+ ]
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+ values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
+ return values
diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py
index d0e7f7b1ef94..7af199a67d3b 100644
--- a/applications/Chat/coati/models/bloom/__init__.py
+++ b/applications/Chat/coati/models/bloom/__init__.py
@@ -2,4 +2,4 @@
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]
diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py
index d7577f096493..73855a2245e7 100644
--- a/applications/Chat/coati/models/bloom/bloom_actor.py
+++ b/applications/Chat/coati/models/bloom/bloom_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomForCausalLM
from ..base import Actor
@@ -18,12 +17,14 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py
index a3716ca94138..b2d838f7ffc5 100644
--- a/applications/Chat/coati/models/bloom/bloom_critic.py
+++ b/applications/Chat/coati/models/bloom/bloom_critic.py
@@ -1,8 +1,7 @@
from typing import Optional
-import torch
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import Critic
@@ -18,12 +17,14 @@ class BLOOMCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py
index e6ca9b1d4851..c09457ddc8c7 100644
--- a/applications/Chat/coati/models/bloom/bloom_rm.py
+++ b/applications/Chat/coati/models/bloom/bloom_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import RewardModel
@@ -17,11 +17,13 @@ class BLOOMRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py
index 373f19553fdc..5956f5a8e91b 100644
--- a/applications/Chat/coati/models/chatglm/__init__.py
+++ b/applications/Chat/coati/models/chatglm/__init__.py
@@ -1,3 +1,3 @@
from .chatglm_actor import ChatGLMActor
-__all__ = ['ChatGLMActor']
\ No newline at end of file
+__all__ = ["ChatGLMActor"]
diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py
index c35d994e9319..00a61561ee47 100644
--- a/applications/Chat/coati/models/chatglm/chatglm_actor.py
+++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -1,11 +1,9 @@
from typing import Optional
-import torch
+from ..base import Actor
from .configuration_chatglm import ChatGLMConfig
from .modeling_chatglm import ChatGLMForConditionalGeneration
-from ..base import Actor
-
class ChatGLMActor(Actor):
"""
@@ -19,10 +17,9 @@ class ChatGLMActor(Actor):
do not support lora for now.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[ChatGLMConfig] = None,
- checkpoint: bool = False) -> None:
+ def __init__(
+ self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
+ ) -> None:
if pretrained is not None:
model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
elif config is not None:
@@ -31,4 +28,4 @@ def __init__(self,
model = ChatGLMForConditionalGeneration(ChatGLMConfig())
if checkpoint:
model.gradient_checkpointing_enable()
- super().__init__(model, lora_rank=0, lora_train_bias='none')
+ super().__init__(model, lora_rank=0, lora_train_bias="none")
diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
index f7717f7e68b6..221ef044b470 100644
--- a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
+++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
@@ -2,15 +2,14 @@
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
"""
"""Tokenization classes for ChatGLM."""
-from typing import List, Optional, Union
import os
+from typing import Dict, List, Optional, Union
-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.utils import logging, PaddingStrategy
-from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
-from typing import Dict
-import sentencepiece as spm
import numpy as np
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
+from transformers.utils import PaddingStrategy, logging
logger = logging.get_logger(__name__)
@@ -52,11 +51,11 @@ def __len__(self):
class SPTokenizer:
def __init__(
- self,
- vocab_file,
- num_image_tokens=20000,
- max_blank_length=80,
- byte_fallback=True,
+ self,
+ vocab_file,
+ num_image_tokens=20000,
+ max_blank_length=80,
+ byte_fallback=True,
):
assert vocab_file is not None
self.vocab_file = vocab_file
@@ -100,9 +99,7 @@ def _preprocess(self, text: str, linebreak=True, whitespaces=True):
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
return text
- def encode(
- self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
- ) -> List[int]:
+ def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@@ -136,9 +133,7 @@ def decode_tokens(self, tokens: List[str]) -> str:
text = self.postprocess(text)
return text
- def tokenize(
- self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
- ) -> List[str]:
+ def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(
- self,
- vocab_file,
- do_lower_case=False,
- remove_space=False,
- bos_token='<sop>',
- eos_token='<eop>',
- end_token='</s>',
- mask_token='[MASK]',
- gmask_token='[gMASK]',
- padding_side="left",
- pad_token="<pad>",
- unk_token="<unk>",
- num_image_tokens=20000,
- **kwargs
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+ bos_token="<sop>",
+ eos_token="<eop>",
+ end_token="</s>",
+ mask_token="[MASK]",
+ gmask_token="[gMASK]",
+ padding_side="left",
+ pad_token="<pad>",
+ unk_token="<unk>",
+ num_image_tokens=20000,
+ **kwargs,
) -> None:
super().__init__(
do_lower_case=do_lower_case,
@@ -208,7 +203,7 @@ def __init__(
pad_token=pad_token,
unk_token=unk_token,
num_image_tokens=num_image_tokens,
- **kwargs
+ **kwargs,
)
self.do_lower_case = do_lower_case
@@ -243,11 +238,11 @@ def end_token_id(self) -> Optional[int]:
@property
def vocab_size(self):
- """ Returns vocab size """
+ """Returns vocab size"""
return self.sp_tokenizer.num_tokens
def get_vocab(self):
- """ Returns vocab as a dict """
+ """Returns vocab as a dict"""
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
@@ -264,7 +259,7 @@ def preprocess_text(self, inputs):
return outputs
def _tokenize(self, text, **kwargs):
- """ Returns a tokenized string. """
+ """Returns a tokenized string."""
text = self.preprocess_text(text)
seq = self.sp_tokenizer.tokenize(text)
@@ -274,11 +269,7 @@ def _tokenize(self, text, **kwargs):
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
- def _decode(
- self,
- token_ids: Union[int, List[int]],
- **kwargs
- ) -> str:
+ def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if len(token_ids) == 0:
@@ -288,7 +279,7 @@ def _decode(
return super()._decode(token_ids, **kwargs)
def _convert_token_to_id(self, token):
- """ Converts a token (str) in an id using the vocab. """
+ """Converts a token (str) in an id using the vocab."""
return self.sp_tokenizer[token]
def _convert_id_to_token(self, index):
@@ -309,13 +300,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
- vocab_file = os.path.join(
- save_directory, self.vocab_files_names["vocab_file"]
- )
+ vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
else:
vocab_file = save_directory
- with open(self.vocab_file, 'rb') as fin:
+ with open(self.vocab_file, "rb") as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
@@ -324,7 +313,7 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
return (vocab_file,)
def build_inputs_with_special_tokens(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
@@ -343,19 +332,19 @@ def build_inputs_with_special_tokens(
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
gmask_id = self.sp_tokenizer[self.gmask_token]
- eos_id = self.sp_tokenizer[self.eos_token]
+ self.sp_tokenizer[self.eos_token]
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1
return token_ids_0
def _pad(
- self,
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
- max_length: Optional[int] = None,
- padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
- pad_to_multiple_of: Optional[int] = None,
- return_attention_mask: Optional[bool] = None,
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -421,17 +410,23 @@ def _pad(
mask_position = required_input.index(mask_token)
position_ids[context_length:] = mask_position
block_position_ids = np.concatenate(
- [np.zeros(context_length, dtype=np.int64),
- np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+ [
+ np.zeros(context_length, dtype=np.int64),
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64),
+ ]
+ )
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
- encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
- pad_width=[(0, 0), (difference, 0), (difference, 0)],
- mode='constant', constant_values=True)
+ encoded_inputs["attention_mask"] = np.pad(
+ encoded_inputs["attention_mask"],
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
+ mode="constant",
+ constant_values=True,
+ )
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
@@ -439,8 +434,9 @@ def _pad(
if "special_tokens_mask" in encoded_inputs:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
if "position_ids" in encoded_inputs:
- encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
- pad_width=[(0, 0), (difference, 0)])
+ encoded_inputs["position_ids"] = np.pad(
+ encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
+ )
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
- return encoded_inputs
\ No newline at end of file
+ return encoded_inputs
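ChatGLM's `_pad` builds two stacked position-id rows: ordinary positions that freeze at the mask position, plus block positions counting upward after the context. A small NumPy illustration of just that construction (lengths chosen arbitrarily):

```python
import numpy as np

seq_length, context_length, mask_position = 8, 5, 3

position_ids = np.arange(seq_length, dtype=np.int64)
position_ids[context_length:] = mask_position  # generated part points at the [gMASK] slot

block_position_ids = np.concatenate(
    [
        np.zeros(context_length, dtype=np.int64),                        # 0 inside the context
        np.arange(1, seq_length - context_length + 1, dtype=np.int64),  # 1..n after it
    ]
)

print(np.stack([position_ids, block_position_ids], axis=0))
# [[0 1 2 3 4 3 3 3]
#  [0 0 0 0 0 1 2 3]]
```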
diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
index d0e3f6cc63d7..a6d2ccd18715 100644
--- a/applications/Chat/coati/models/chatglm/configuration_chatglm.py
+++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
- ```
-"""
+ ```"""
model_type = "chatglm"
def __init__(
- self,
- vocab_size=130528,
- hidden_size=4096,
- num_layers=28,
- num_attention_heads=32,
- layernorm_epsilon=1e-5,
- use_cache=True,
- bos_token_id=130004,
- eos_token_id=130005,
- mask_token_id=130000,
- gmask_token_id=130001,
- pad_token_id=3,
- max_sequence_length=2048,
- inner_hidden_size=16384,
- position_encoding_2d=True,
- quantization_bit=0,
- pre_seq_len=None,
- prefix_projection=False,
- **kwargs
+ self,
+ vocab_size=130528,
+ hidden_size=4096,
+ num_layers=28,
+ num_attention_heads=32,
+ layernorm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=130004,
+ eos_token_id=130005,
+ mask_token_id=130000,
+ gmask_token_id=130001,
+ pad_token_id=3,
+ max_sequence_length=2048,
+ inner_hidden_size=16384,
+ position_encoding_2d=True,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs,
):
self.num_layers = num_layers
self.vocab_size = vocab_size
@@ -99,9 +98,4 @@ def __init__(
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- **kwargs
- )
\ No newline at end of file
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
index 77e7d0d8ea09..d1d15c68ffd8 100644
--- a/applications/Chat/coati/models/chatglm/modeling_chatglm.py
+++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
@@ -4,41 +4,40 @@
""" PyTorch ChatGLM model. """
-import math
import copy
+import math
import os
-import warnings
import re
import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
import torch.nn.functional as F
+import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-
-from transformers.utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
-)
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
- CausalLMOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
-if sys.platform != 'darwin':
+if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(
- n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
- for n in name
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
array = np.transpose(array)
try:
assert (
- pointer.shape == array.shape
+ pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
@@ -153,7 +152,7 @@ def __init__(self, config):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
- torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -170,8 +169,7 @@ def forward(self, prefix: torch.Tensor):
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
- return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
- (1.0 + 0.044715 * x * x)))
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def gelu(x):
@@ -181,21 +179,22 @@ def gelu(x):
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
- self.register_buffer('inv_freq', inv_freq)
+ self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
- error_msgs):
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
pass
def forward(self, x, seq_dim=1, seq_len=None):
@@ -204,7 +203,7 @@ def forward(self, x, seq_dim=1, seq_len=None):
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
@@ -230,30 +229,31 @@ def _apply(self, fn):
def rotate_half(x):
- x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
- cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
- F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
+ position_id, sin.squeeze(1)
+ ).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
def attention_fn(
- self,
- query_layer,
- key_layer,
- value_layer,
- attention_mask,
- hidden_size_per_partition,
- layer_id,
- layer_past=None,
- scaling_attention_score=True,
- use_cache=False,
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ hidden_size_per_partition,
+ layer_id,
+ layer_past=None,
+ scaling_attention_score=True,
+ use_cache=False,
):
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
@@ -285,7 +285,9 @@ def attention_fn(
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_result = torch.zeros(
- 1, 1, 1,
+ 1,
+ 1,
+ 1,
dtype=query_layer.dtype,
device=query_layer.device,
)
@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
class SelfAttention(torch.nn.Module):
- def __init__(self, hidden_size, num_attention_heads,
- layer_id, hidden_size_per_attention_head=None, bias=True,
- params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=None,
+ bias=True,
+ params_dtype=torch.float,
+ position_encoding_2d=True,
+ empty_init=True,
+ ):
if empty_init:
init_method = skip_init
else:
@@ -410,8 +420,7 @@ def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
- def split_tensor_along_last_dim(self, tensor, num_partitions,
- contiguous_split_chunks=False):
+ def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
@@ -431,14 +440,14 @@ def split_tensor_along_last_dim(self, tensor, num_partitions,
return tensor_list
def forward(
- self,
- hidden_states: torch.Tensor,
- position_ids,
- attention_mask: torch.Tensor,
- layer_id,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@@ -462,8 +471,10 @@ def forward(
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
- position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
- position_ids[:, 1, :].transpose(0, 1).contiguous()
+ position_ids, block_position_ids = (
+ position_ids[:, 0, :].transpose(0, 1).contiguous(),
+ position_ids[:, 1, :].transpose(0, 1).contiguous(),
+ )
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
@@ -484,7 +495,7 @@ def forward(
hidden_size_per_partition=self.hidden_size_per_partition,
layer_id=layer_id,
layer_past=layer_past,
- use_cache=use_cache
+ use_cache=use_cache,
)
output = self.dense(context_layer)
@@ -509,8 +520,16 @@ def forward(self, x):
class GLU(torch.nn.Module):
- def __init__(self, hidden_size, inner_hidden_size=None,
- layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
+ def __init__(
+ self,
+ hidden_size,
+ inner_hidden_size=None,
+ layer_id=None,
+ bias=True,
+ activation_func=gelu,
+ params_dtype=torch.float,
+ empty_init=True,
+ ):
super(GLU, self).__init__()
if empty_init:
init_method = skip_init
@@ -557,19 +576,19 @@ def forward(self, hidden_states):
class GLMBlock(torch.nn.Module):
def __init__(
- self,
- hidden_size,
- num_attention_heads,
- layernorm_epsilon,
- layer_id,
- inner_hidden_size=None,
- hidden_size_per_attention_head=None,
- layernorm=LayerNorm,
- use_bias=True,
- params_dtype=torch.float,
- num_layers=28,
- position_encoding_2d=True,
- empty_init=True
+ self,
+ hidden_size,
+ num_attention_heads,
+ layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=None,
+ hidden_size_per_attention_head=None,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=torch.float,
+ num_layers=28,
+ position_encoding_2d=True,
+ empty_init=True,
):
super(GLMBlock, self).__init__()
# Set output layer initialization if not provided.
@@ -590,7 +609,7 @@ def __init__(
bias=use_bias,
params_dtype=params_dtype,
position_encoding_2d=self.position_encoding_2d,
- empty_init=empty_init
+ empty_init=empty_init,
)
# Layernorm on the input data.
@@ -605,18 +624,18 @@ def __init__(
bias=use_bias,
layer_id=layer_id,
params_dtype=params_dtype,
- empty_init=empty_init
+ empty_init=empty_init,
)
def forward(
- self,
- hidden_states: torch.Tensor,
- position_ids,
- attention_mask: torch.Tensor,
- layer_id,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@@ -635,7 +654,7 @@ def forward(
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache,
- output_attentions=output_attentions
+ output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
@@ -702,10 +721,15 @@ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
- block_position_ids = [torch.cat((
- torch.zeros(context_length, dtype=torch.long, device=device),
- torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
- )) for context_length in context_lengths]
+ block_position_ids = [
+ torch.cat(
+ (
+ torch.zeros(context_length, dtype=torch.long, device=device),
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
+ )
+ )
+ for context_length in context_lengths
+ ]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
@@ -823,9 +847,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True):
self.prefix_projection = config.prefix_projection
self.word_embeddings = init_method(
- torch.nn.Embedding,
- num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
- dtype=self.params_dtype
+ torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
)
self.gradient_checkpointing = False
@@ -841,12 +863,10 @@ def get_layer(layer_id):
use_bias=True,
params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d,
- empty_init=empty_init
+ empty_init=empty_init,
)
- self.layers = torch.nn.ModuleList(
- [get_layer(layer_id) for layer_id in range(self.num_layers)]
- )
+ self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
@@ -876,7 +896,7 @@ def get_prompt(self, batch_size, device, dtype=torch.half):
self.pre_seq_len,
self.num_layers * 2,
self.num_attention_heads,
- self.hidden_size // self.num_attention_heads
+ self.hidden_size // self.num_attention_heads,
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
@@ -891,18 +911,17 @@ def get_prompt(self, batch_size, device, dtype=torch.half):
config_class=_CONFIG_FOR_DOC,
)
def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
- inputs_embeds: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -931,17 +950,14 @@ def forward(
if past_key_values is None:
if self.pre_seq_len is not None:
- past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
- dtype=inputs_embeds.dtype)
+ past_key_values = self.get_prompt(
+ batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
+ )
else:
past_key_values = tuple([None] * len(self.layers))
if attention_mask is None:
- attention_mask = self.get_masks(
- input_ids,
- device=input_ids.device
- )
-
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -955,15 +971,13 @@ def forward(
use_gmasks.append(use_gmask)
position_ids = self.get_position_ids(
- input_ids,
- mask_positions=mask_positions,
- device=input_ids.device,
- use_gmasks=use_gmasks
+ input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
)
if self.pre_seq_len is not None and attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
- attention_mask.device)
+ attention_mask.device
+ )
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
@@ -980,7 +994,6 @@ def forward(
attention_mask = attention_mask.to(hidden_states.device)
for i, layer in enumerate(self.layers):
-
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i]
@@ -994,7 +1007,7 @@ def forward(
torch.tensor(i),
layer_past,
use_cache,
- output_attentions
+ output_attentions,
)
else:
layer_ret = layer(
@@ -1004,7 +1017,7 @@ def forward(
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
- output_attentions=output_attentions
+ output_attentions=output_attentions,
)
hidden_states = layer_ret[0]
@@ -1049,13 +1062,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True):
self.transformer = ChatGLMModel(config, empty_init=empty_init)
- self.lm_head = init_method(
- nn.Linear,
- config.hidden_size,
- config.vocab_size,
- bias=False,
- dtype=torch.half
- )
+ self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
self.config = config
@@ -1087,32 +1094,29 @@ def _update_model_kwargs_for_generation(
attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat(
- [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
+ )
new_attention_mask = attention_mask[:, :, -1:].clone()
new_attention_mask[..., -1] = False
- model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, new_attention_mask], dim=2
- )
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1
- model_kwargs["position_ids"] = torch.cat(
- [position_ids, new_position_id], dim=-1
- )
+ model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
return model_kwargs
def prepare_inputs_for_generation(
- self,
- input_ids: torch.LongTensor,
- past: Optional[torch.Tensor] = None,
- past_key_values: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- **kwargs
+ self,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -1137,11 +1141,17 @@ def prepare_inputs_for_generation(
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d:
position_ids = torch.tensor(
- [[mask_position, seq_length - context_length] for mask_position, context_length in
- zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+ [
+ [mask_position, seq_length - context_length]
+ for mask_position, context_length in zip(mask_positions, context_lengths)
+ ],
+ dtype=torch.long,
+ device=input_ids.device,
+ ).unsqueeze(-1)
else:
- position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
- device=input_ids.device).unsqueeze(-1)
+ position_ids = torch.tensor(
+ [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
+ ).unsqueeze(-1)
if past is None:
past = past_key_values
@@ -1149,44 +1159,38 @@ def prepare_inputs_for_generation(
"input_ids": last_token,
"past_key_values": past,
"position_ids": position_ids,
- "attention_mask": attention_mask
+ "attention_mask": attention_mask,
}
else:
if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None
if attention_mask is None:
- attention_mask = self.get_masks(
- input_ids,
- device=input_ids.device
- )
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
position_ids = self.get_position_ids(
- input_ids,
- device=input_ids.device,
- mask_positions=mask_positions,
- use_gmasks=use_gmasks
+ input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
)
return {
"input_ids": input_ids,
"past_key_values": past,
"position_ids": position_ids,
- "attention_mask": attention_mask
+ "attention_mask": attention_mask,
}
def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1235,7 +1239,7 @@ def forward(
@staticmethod
def _reorder_cache(
- past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -1268,15 +1272,33 @@ def process_response(self, response):
return response
@torch.no_grad()
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_length: int = 2048,
+ num_beams=1,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ gen_kwargs = {
+ "max_length": max_length,
+ "num_beams": num_beams,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
if not history:
prompt = query
else:
@@ -1287,22 +1309,38 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs)
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_length: int = 2048,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ gen_kwargs = {
+ "max_length": max_length,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
if not history:
prompt = query
else:
@@ -1313,7 +1351,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
for outputs in self.stream_generate(**inputs, **gen_kwargs):
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
@@ -1321,13 +1359,13 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
@torch.no_grad()
def stream_generate(
- self,
- input_ids,
- generation_config: Optional[GenerationConfig] = None,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
- **kwargs,
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ **kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
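A condensed sketch of the rotary-embedding math reformatted in this file (`rotate_half` plus the cos/sin mix), reduced to a single head; names and sizes here are illustrative, not the model's actual configuration:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dim and negate the second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate pairs of dimensions by a position-dependent angle.
    return (q * cos) + (rotate_half(q) * sin)


dim, seq = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)       # (seq, dim)

q = torch.randn(seq, dim)
q_rot = apply_rotary(q, emb.cos(), emb.sin())
print(q_rot.shape)  # torch.Size([4, 8])
```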
diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index de0d63f95f50..4ab0cdc8a3ea 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -2,6 +2,7 @@
import torch
import torch.distributed as dist
+from transformers import PreTrainedTokenizer
from .base import Actor
@@ -16,9 +17,9 @@
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def _prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -37,18 +38,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def _sample(model: Actor,
- input_ids: torch.Tensor,
- max_length: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def _sample(
+ model: Actor,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
@@ -56,12 +59,13 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
- if prepare_inputs_fn is not None else {'input_ids': input_ids}
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
- # pre-process distribution
+ # NOTE: this is correct only in left padding mode
+ next_token_logits = outputs["logits"][:, -1, :]
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
@@ -69,8 +73,7 @@ def _sample(model: Actor,
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step
@@ -90,20 +93,21 @@ def _sample(model: Actor,
@torch.no_grad()
-def generate(model: Actor,
- input_ids: torch.Tensor,
- max_length: int,
- num_beams: int = 1,
- do_sample: bool = True,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def generate(
+ model: Actor,
+ input_ids: torch.Tensor,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = False,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
@@ -113,34 +117,35 @@ def generate(model: Actor,
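+        tokenizer (PreTrainedTokenizer): tokenizer used to obtain eos_token_id and pad_token_id.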
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
- eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
- pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
- is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
- is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
- is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+ assert tokenizer.padding_side == "left", "Current generation only supports left padding."
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+ is_sample_gen_mode = (num_beams == 1) and do_sample is True
+ is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
- return _sample(model,
- input_ids,
- max_length,
- early_stopping=early_stopping,
- eos_token_id=eos_token_id,
- pad_token_id=pad_token_id,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- prepare_inputs_fn=prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
- **model_kwargs)
+ return _sample(
+ model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs,
+ )
elif is_beam_gen_mode:
raise NotImplementedError
else:
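
The refactored `_sample` loop above draws each next token after the logits pass through temperature, top-k, and top-p warpers. Below is a rough standalone sketch of that decoding step, assuming only `torch`; the function name and shapes are illustrative rather than part of coati's API.

```python
import torch

def sample_next_token(
    logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
) -> torch.Tensor:
    # Temperature scaling: values < 1 sharpen the distribution, values > 1 flatten it.
    if temperature != 1.0:
        logits = logits / temperature
    # Top-k: mask everything below the k-th largest logit.
    if top_k > 0:
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    # Top-p (nucleus): keep the smallest prefix of sorted tokens whose mass exceeds top_p.
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_mask = cumulative > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()  # shift right: keep the boundary token
        sorted_mask[..., 0] = False
        logits = logits.masked_fill(sorted_mask.scatter(-1, sorted_idx, sorted_mask), float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape (batch, 1)

next_tokens = sample_next_token(torch.randn(2, 100), temperature=0.7, top_k=50, top_p=0.9)
```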
diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py
index 63dc5ab0f5ea..823cf4a75e0d 100644
--- a/applications/Chat/coati/models/gpt/__init__.py
+++ b/applications/Chat/coati/models/gpt/__init__.py
@@ -2,4 +2,4 @@
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py
index ae9d669f1f56..a7e4b9bc3e22 100644
--- a/applications/Chat/coati/models/gpt/gpt_actor.py
+++ b/applications/Chat/coati/models/gpt/gpt_actor.py
@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py
index 01e1cd10ef57..22ab36dea276 100644
--- a/applications/Chat/coati/models/gpt/gpt_critic.py
+++ b/applications/Chat/coati/models/gpt/gpt_critic.py
@@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py
index e52a5a14c1da..8edfc4008466 100644
--- a/applications/Chat/coati/models/gpt/gpt_rm.py
+++ b/applications/Chat/coati/models/gpt/gpt_rm.py
@@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py
index 9b2a024afdb2..c87d732538a9 100644
--- a/applications/Chat/coati/models/llama/__init__.py
+++ b/applications/Chat/coati/models/llama/__init__.py
@@ -2,4 +2,4 @@
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py
index 2c7adb390d8b..f1d9406835ca 100644
--- a/applications/Chat/coati/models/llama/llama_actor.py
+++ b/applications/Chat/coati/models/llama/llama_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaForCausalLM
from ..base import Actor
@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index a67e5de5def6..000dce17ccf0 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py
index d6b62922686e..43bc9e638dc7 100644
--- a/applications/Chat/coati/models/llama/llama_rm.py
+++ b/applications/Chat/coati/models/llama/llama_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel
@@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 546f675d7d37..e9bd7b2ed8f0 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -1,4 +1,6 @@
+import dataclasses
import math
+import warnings
from typing import Optional
import loralib as lora
@@ -7,9 +9,16 @@
import torch.nn.functional as F
+@dataclasses.dataclass
+class LoRAManager:
+ merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
class LoraLinear(lora.LoRALayer, nn.Module):
- """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
- """
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
@@ -17,16 +26,12 @@ def __init__(
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
- merge_weights: bool = True,
+ lora_dropout: float = 0.0,
+ # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ fan_in_fan_out: bool = False,
):
nn.Module.__init__(self)
- lora.LoRALayer.__init__(self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=merge_weights)
+ lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self.weight = weight
self.bias = bias
@@ -47,45 +52,42 @@ def __init__(
self.weight.data = self.weight.data.T
def reset_parameters(self):
- if hasattr(self, 'lora_A'):
- # initialize A the same way as the default for nn.Linear and B to zero
+ if hasattr(self, "lora_A"):
+ # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
-
- def T(w):
- return w.T if self.fan_in_fan_out else w
-
- nn.Module.train(self, mode)
- if self.merge_weights and self.merged:
- # Make sure that the weights are not merged
- if self.r > 0:
- if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
- # FIXME(csric): temporary fix
- self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
- self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
- self.reset_parameters()
- else:
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = False
-
- def eval(self):
-
def T(w):
return w.T if self.fan_in_fan_out else w
- nn.Module.eval(self)
- if self.merge_weights and not self.merged:
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- delattr(self, 'lora_A')
- delattr(self, 'lora_B')
- self.merged = True
+ self.training = mode
+ if LORA_MANAGER.merge_weights:
+ if mode and self.merged:
+ warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+ # FIXME(csric): temporary fix
+ self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+ self.reset_parameters()
+ else:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+ elif not mode and not self.merged:
+ warnings.warn("Invoke module.eval() would merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
+
+ return self
def forward(self, x: torch.Tensor):
-
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -99,8 +101,10 @@ def T(w):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
- assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
- lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+ assert (
+ lora_rank <= linear.in_features
+ ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
+ lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
return lora_linear
@@ -112,7 +116,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank)
-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
@@ -140,7 +144,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""
- def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
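
For orientation, the behavior `LoraLinear` implements is `y = x @ W.T + (x @ A.T @ B.T) * (lora_alpha / r)`, with `A` initialized like `nn.Linear` and `B` at zero so training starts from the base model; merging folds `B @ A` into `weight` so inference pays no extra matmul. A simplified sketch without the global `LORA_MANAGER`, dropout, or bias handling (the class name is hypothetical):

```python
import math
import torch
import torch.nn as nn

class TinyLoraLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        # The base weight stays frozen; only the low-rank factors are trained.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # A is initialized like nn.Linear; B starts at zero, so the update is a no-op at first.
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Out-of-place ops only, mirroring the constraint noted in the docstring above.
        return x @ self.weight.T + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

    def merge(self) -> None:
        # Fold the low-rank product into the frozen weight for inference.
        self.weight.data += (self.lora_B @ self.lora_A) * self.scaling

layer = TinyLoraLinear(32, 16)
out = layer(torch.randn(4, 32))  # shape (4, 16)
```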
diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py
index 05a0b4821797..687bd0f7bfe7 100644
--- a/applications/Chat/coati/models/loss.py
+++ b/applications/Chat/coati/models/loss.py
@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
+ # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
@@ -31,11 +32,13 @@ def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- log_probs: torch.Tensor,
- old_log_probs: torch.Tensor,
- advantages: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
@@ -55,14 +58,16 @@ def __init__(self, clip_eps: float = 0.4) -> None:
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- values: torch.Tensor,
- old_values: torch.Tensor,
- reward: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ values: torch.Tensor,
+ old_values: torch.Tensor,
+ reward: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
- surr1 = (values_clipped - reward)**2
- surr2 = (values - reward)**2
+ surr1 = (values_clipped - reward) ** 2
+ surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
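
`PolicyLoss` and `ValueLoss` above are the standard PPO clipped objectives. A minimal numeric sketch of the policy term, assuming unmasked tensors (the real modules also accept an `action_mask`):

```python
import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    # Probability ratio pi_new / pi_old, computed in log space for stability.
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    # PPO maximizes min(surr1, surr2); negate it to obtain a loss.
    return -torch.min(surr1, surr2).mean()

log_probs = torch.tensor([-0.9, -1.2])
old_log_probs = torch.tensor([-1.0, -1.0])
advantages = torch.tensor([1.0, -0.5])
print(clipped_policy_loss(log_probs, old_log_probs, advantages))
```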
diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py
index 334f4df0032a..e37d6e45c8fc 100644
--- a/applications/Chat/coati/models/opt/__init__.py
+++ b/applications/Chat/coati/models/opt/__init__.py
@@ -2,4 +2,4 @@
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py
index c14e4377ffb2..cd8908e13fb8 100644
--- a/applications/Chat/coati/models/opt/opt_actor.py
+++ b/applications/Chat/coati/models/opt/opt_actor.py
@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py
index f66c4173fa52..f37d28812c27 100644
--- a/applications/Chat/coati/models/opt/opt_critic.py
+++ b/applications/Chat/coati/models/opt/opt_critic.py
@@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py
index 6f75344e6aae..893708344ad4 100644
--- a/applications/Chat/coati/models/opt/opt_rm.py
+++ b/applications/Chat/coati/models/opt/opt_rm.py
@@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index 97637d3523b0..1aaef16620d2 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -4,9 +4,9 @@
import torch.nn.functional as F
-def _compute_approx_kl(log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(
+ log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl
-def compute_reward(r: Union[torch.Tensor, float],
- kl_coef: float,
- log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def compute_reward(
+ r: Union[torch.Tensor, float],
+ kl_coef: float,
+ log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
@@ -44,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
return log_probs_labels.squeeze(-1)
-def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
- output (torch.Tensor): Output tensor of Actor.forward.
+        logits (torch.Tensor): Output logits tensor of Actor.forward, i.e. output["logits"].
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
- logits = output['logits']
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
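
`compute_reward` above penalizes the scalar reward with an approximate KL between the actor and the frozen initial policy (the k1 estimator from the Schulman post cited in the docstring). A toy sketch under that definition, with illustrative names and no action mask:

```python
import torch

def kl_shaped_reward(r, kl_coef, log_probs, log_probs_base):
    # k1 estimator: mean of (log p_actor - log p_base) over generated tokens.
    approx_kl = (log_probs - log_probs_base).mean(dim=-1)
    return r - kl_coef * approx_kl

actor_lp = torch.tensor([[-1.0, -0.8, -1.1]])
base_lp = torch.tensor([[-1.2, -0.9, -1.0]])
print(kl_shaped_reward(1.0, 0.1, actor_lp, base_lp))  # reward shrinks as the actor drifts
```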
diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py
index a65a78d07bb8..1765b8091bc3 100644
--- a/applications/Chat/coati/quant/__init__.py
+++ b/applications/Chat/coati/quant/__init__.py
@@ -2,6 +2,6 @@
from .utils import low_resource_init
__all__ = [
- 'llama_load_quant',
- 'low_resource_init',
+ "llama_load_quant",
+ "low_resource_init",
]
diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py
index 51c8d6316290..51d5233586ad 100644
--- a/applications/Chat/coati/quant/llama_gptq/__init__.py
+++ b/applications/Chat/coati/quant/llama_gptq/__init__.py
@@ -1,5 +1,5 @@
from .loader import load_quant
__all__ = [
- 'load_quant',
+ "load_quant",
]
diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py
index 5353dc8a2ea3..50486337a7ab 100644
--- a/applications/Chat/coati/quant/llama_gptq/loader.py
+++ b/applications/Chat/coati/quant/llama_gptq/loader.py
@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
# ignore lm head
layers = find_layers(model)
- for name in ['lm_head']:
+ for name in ["lm_head"]:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
- if checkpoint.endswith('.safetensors'):
+ if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file as safe_load
+
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py
index 62db171abb52..18e4e4761500 100644
--- a/applications/Chat/coati/quant/llama_gptq/model_utils.py
+++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py
@@ -1,13 +1,12 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
-import torch
import torch.nn as nn
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
- res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
+ res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py
index f7d5b7ce4bd8..5a7e2e72dfc5 100644
--- a/applications/Chat/coati/quant/llama_gptq/quant.py
+++ b/applications/Chat/coati/quant/llama_gptq/quant.py
@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module):
-
def __init__(self, shape=1):
super(Quantizer, self).__init__()
- self.register_buffer('maxq', torch.tensor(0))
- self.register_buffer('scale', torch.zeros(shape))
- self.register_buffer('zero', torch.zeros(shape))
+ self.register_buffer("maxq", torch.tensor(0))
+ self.register_buffer("scale", torch.zeros(shape))
+ self.register_buffer("zero", torch.zeros(shape))
- def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
+ def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@@ -68,7 +67,7 @@ def find_params(self, x, weight=False):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
- best = torch.full([x.shape[0]], float('inf'), device=dev)
+ best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
@@ -123,13 +122,12 @@ def ready(self):
try:
import quant_cuda
except:
- print('CUDA extension not installed.')
+ print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
-
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
@@ -142,11 +140,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures):
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
- 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
- dtype=torch.int))
- self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
- self.register_buffer('bias', torch.zeros(outfeatures))
- self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
+ "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+ )
+ self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
+ self.register_buffer("bias", torch.zeros(outfeatures))
+ self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
@@ -161,8 +159,10 @@ def pack(self, linear, scales, zeros):
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
- torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
- None])
+ torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
+ :, None
+ ]
+ )
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
@@ -271,13 +271,13 @@ def forward(self, x):
return y.reshape(outshape)
-def make_quant(module, names, bits, groupsize, name=''):
+def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
- name1 = name + '.' + attr if name != '' else attr
+ name1 = name + "." + attr if name != "" else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
- make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+ make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
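
For reference, the `Quantizer` above implements the affine scheme `q = clamp(round(x / scale) + zero, 0, maxq)` with dequantization `scale * (q - zero)`. A 4-bit round-trip sketch using the same asymmetric `scale`/`zero` formulas as `find_params` (no CUDA kernel involved; the helper name is hypothetical):

```python
import torch

def fake_quantize(x, scale, zero, maxq):
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)  # dequantized value

bits = 4
maxq = torch.tensor(2**bits - 1)
x = torch.randn(8)
xmin, xmax = x.min().clamp(max=0), x.max().clamp(min=0)  # keep 0 inside the range
scale = (xmax - xmin) / maxq
zero = torch.round(-xmin / scale)
print((x - fake_quantize(x, scale, zero, maxq)).abs().max())  # worst-case quantization error
```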
diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py
index 01b8cff0add1..d102bb30f52d 100644
--- a/applications/Chat/coati/quant/utils.py
+++ b/applications/Chat/coati/quant/utils.py
@@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
@contextmanager
def low_resource_init():
- """This context manager disables weight initialization and sets the default float dtype to half.
- """
+ """This context manager disables weight initialization and sets the default float dtype to half."""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_
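
`low_resource_init` lets a large model be instantiated quickly before its quantized checkpoint overwrites the weights: it swaps the common `torch.nn.init` functions for no-ops and switches the default dtype to half. A sketch of the same pattern under a hypothetical name (`fast_empty_init` is not the coati API):

```python
from contextlib import contextmanager

import torch

def _noop(*args, **kwargs):
    pass

@contextmanager
def fast_empty_init():
    # Real weights come from the checkpoint, so skipping init loses nothing.
    saved = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_)
    torch.nn.init.kaiming_uniform_ = torch.nn.init.uniform_ = torch.nn.init.normal_ = _noop
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.half)
    try:
        yield
    finally:
        torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = saved
        torch.set_default_dtype(old_dtype)
```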
diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py
index 3306150a41ff..8c5bd8a67776 100644
--- a/applications/Chat/coati/ray/callbacks/base.py
+++ b/applications/Chat/coati/ray/callbacks/base.py
@@ -5,7 +5,7 @@
class TrainerCallback(ABC):
"""
- Base callback class. It defines the interface for callbacks.
+ Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@@ -40,7 +40,6 @@ def on_update_end(self) -> None:
class MakerCallback(ABC):
-
def on_loop_start(self) -> None:
pass
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
index d3df8f9ae3e0..18798bce7dce 100644
--- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
-
def __init__(self) -> None:
self.start_time: Optional[float] = None
- self.duration: float = 0.
+ self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -42,13 +41,13 @@ def end(self) -> None:
self.duration += time() - self.start_time
def reset(self) -> None:
- self.duration = 0.
+ self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback):
-
- def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
- reward_model_num_params: int) -> None:
+ def __init__(
+ self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -63,7 +62,7 @@ def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_
self.make_experience_flop: int = 0
print_rank_0(
- f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
+ f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_make_experience_start(self) -> None:
@@ -110,27 +109,29 @@ def on_loop_end(self) -> None:
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
- avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
- (self.total_samples * self.world_size)
+ avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
+ self.total_samples * self.world_size
+ )
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
- + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
- + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
- + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-
- + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ "Making Experience Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
class TrainerPerformanceEvaluator(TrainerCallback):
-
- def __init__(self,
- actor_num_params: int,
- critic_num_params: int,
- enable_grad_checkpoint: bool = False,
- ignore_first_episodes: int = 1) -> None:
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_first_episodes: int = 1,
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -146,7 +147,7 @@ def __init__(self,
self.learn_flop: int = 0
print_rank_0(
- f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
+ f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_episode_start(self, episodes: int) -> None:
@@ -191,7 +192,7 @@ def on_update_end(self) -> None:
def on_fit_end(self) -> None:
if self.total_samples == 0:
- print_rank_0('No samples are collected, skip trainer performance evaluation')
+ print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
@@ -204,9 +205,10 @@ def on_fit_end(self) -> None:
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
- + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
- + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-
- + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ "Learning Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
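
The evaluators above accumulate wall time in `Timer` objects and average the totals across ranks before printing the throughput summaries. A single-process sketch of that measurement pattern (the distributed all-reduce is omitted):

```python
from time import sleep, time

class Timer:
    def __init__(self):
        self.start_time = None
        self.duration = 0.0

    def start(self):
        self.start_time = time()

    def end(self):
        self.duration += time() - self.start_time

timer, total_samples = Timer(), 0
for _ in range(3):  # pretend each step processes a batch of 8 samples
    timer.start()
    sleep(0.01)  # stands in for make_experience / training work
    timer.end()
    total_samples += 8
print(f"Throughput: {total_samples / (timer.duration + 1e-12):.3f} samples/sec")
```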
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
index 7b9df2ee139b..92dab17292f7 100644
--- a/applications/Chat/coati/ray/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -1,22 +1,17 @@
-import asyncio
-import copy
-import random
-from threading import Lock
-from typing import Any, List
+from typing import List
-import ray
import torch
-from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
+
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
- '''
+ """
Detached replay buffer. Share Experience across workers on the same node.
- Therefore a trainer node is expected to have only one instance.
+ Therefore, a trainer node is expected to have only one instance.
    It is ExperienceMakerHolder's duty to call the append(exp) method remotely.
Args:
@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
        tp_world_size: Number of workers in the same tensor-parallel group.
        limit: Limit on the number of experience sample batches. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload: Whether to offload experience to CPU when sampling. Defaults to True.
- '''
+ """
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
@@ -34,23 +29,23 @@ def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
@torch.no_grad()
def append(self, experience: Experience) -> None:
- '''
+ """
Expected to be called remotely.
- '''
+ """
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
- '''
+ """
Expected to be called remotely.
- '''
+ """
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
- items = self.batch_collector[:self.sample_batch_size]
+ items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
- self.batch_collector = self.batch_collector[self.sample_batch_size:]
+ self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None:
# self.items.close()
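
`extend` above buffers incoming `BufferItem`s and emits a full `Experience` batch onto a Ray queue whenever enough items have accumulated. The same slicing logic, sketched with plain lists and a standard-library queue (the class name is hypothetical):

```python
from queue import Queue

class MiniBatchBuffer:
    def __init__(self, sample_batch_size: int):
        self.sample_batch_size = sample_batch_size
        self.batch_collector: list = []
        self.items: Queue = Queue()

    def extend(self, items: list) -> None:
        # Emit one batch each time sample_batch_size items have accumulated.
        self.batch_collector.extend(items)
        while len(self.batch_collector) >= self.sample_batch_size:
            self.items.put(self.batch_collector[: self.sample_batch_size], block=True)
            self.batch_collector = self.batch_collector[self.sample_batch_size :]

buf = MiniBatchBuffer(sample_batch_size=4)
buf.extend(list(range(10)))
print(buf.items.qsize(), buf.batch_collector)  # 2 [8, 9]
```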
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
index 90399781187a..fcf0a472df9e 100644
--- a/applications/Chat/coati/ray/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, List
import ray
import torch
@@ -15,7 +15,7 @@
class DetachedTrainer(ABC):
- '''
+ """
Base class for detached rlhf trainers.
    'detach' means that the experience maker runs separately from the trainer, unlike in a normal Trainer.
Please set name attribute during init:
@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
-
- def __init__(self,
- experience_maker_holder_name_list: List[str],
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- dataloader_pin_memory: bool = True,
- callbacks: List[TrainerCallback] = [],
- debug: bool = False) -> None:
+ """
+
+ def __init__(
+ self,
+ experience_maker_holder_name_list: List[str],
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[TrainerCallback] = [],
+ debug: bool = False,
+ ) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
@@ -67,18 +69,16 @@ def training_step(self, experience: Experience) -> Dict[str, Any]:
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
- pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
+ pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
- dataloader = DataLoader(data,
- batch_size=1,
- shuffle=True,
- pin_memory=self.dataloader_pin_memory,
- collate_fn=lambda x: x[0])
+ dataloader = DataLoader(
+ data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
+ )
for epoch in range(1, train_epochs):
- pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
+ pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
@@ -104,7 +104,7 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
- for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+ for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
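
One detail worth noting in `_learn` above: the first epoch drains fresh experience from the buffer into `data`, and later epochs replay those stored batches through a `DataLoader` with `batch_size=1` and an identity collate function, so each yielded element is already a full batch. A minimal sketch of that replay trick:

```python
from torch.utils.data import DataLoader

data = [{"step": i} for i in range(4)]  # each element is already a full batch
# batch_size=1 plus an identity collate re-yields the stored batches, shuffled per epoch.
loader = DataLoader(data, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
for epoch in range(2):
    for batch in loader:
        pass  # training_step(batch) would go here
```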
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
index 2f2aa0e29579..ef84a1ddba48 100644
--- a/applications/Chat/coati/ray/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -1,12 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple
import ray
import torch
-from coati.experience_maker import Experience, NaiveExperienceMaker
+from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
-from coati.trainer.callbacks import Callback
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
@@ -14,27 +13,14 @@
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
-from .utils import (
- get_actor_from_args,
- get_critic_from_args,
- get_model_numel,
- get_rank,
- get_strategy_from_args,
- is_rank_0,
- set_dist_env,
- state_dict_to,
-)
+from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
-@ray.remote(concurrency_groups={
- "buffer_length": 1,
- "buffer_append": 1,
- "buffer_sample": 1,
- "model_io": 1,
- "compute": 1
-})
+@ray.remote(
+ concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
+)
class DetachedPPOTrainer(DetachedTrainer):
- '''
+ """
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
+ """
def __init__(
self,
@@ -92,21 +78,24 @@ def __init__(
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
- (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
- self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim)
+ )
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
- super().__init__(experience_maker_holder_name_list,
- train_batch_size=train_batch_size,
- buffer_limit=buffer_limit,
- dataloader_pin_memory=dataloader_pin_memory,
- callbacks=callbacks,
- debug=debug)
+ super().__init__(
+ experience_maker_holder_name_list,
+ train_batch_size=train_batch_size,
+ buffer_limit=buffer_limit,
+ dataloader_pin_memory=dataloader_pin_memory,
+ callbacks=callbacks,
+ debug=debug,
+ )
if self._debug:
- print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
+ print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights
@@ -115,7 +104,7 @@ def __init__(
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
- config['requires_grad_only'] = True
+ config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
@@ -131,7 +120,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
- fully_update=fully_update))
+ fully_update=fully_update,
+ )
+ )
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
@@ -139,7 +130,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
- fully_update=fully_update))
+ fully_update=fully_update,
+ )
+ )
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
@@ -152,26 +145,24 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self.critic(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self.critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
- return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+ return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
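
`DetachedPPOTrainer` is declared with Ray concurrency groups so buffer I/O, model I/O, and compute calls do not serialize behind one another. A minimal actor following the same decorator pattern (assumes Ray is installed and initialized; the class and group names here are illustrative):

```python
import ray

@ray.remote(concurrency_groups={"io": 1, "compute": 1})
class Worker:
    def __init__(self):
        self.weights = 0

    @ray.method(concurrency_group="io")
    def update_weights(self, w):
        # Runs in the "io" group, so it need not queue behind compute calls.
        self.weights = w

    @ray.method(concurrency_group="compute")
    def train_step(self):
        return self.weights + 1

if __name__ == "__main__":
    ray.init()
    worker = Worker.remote()
    ray.get(worker.update_weights.remote(41))
    print(ray.get(worker.train_step.remote()))  # 42
```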
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
index 13314bdafd5f..4d290f4aba88 100644
--- a/applications/Chat/coati/ray/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -1,53 +1,49 @@
import os
import time
import tracemalloc
-from copy import deepcopy
from threading import Lock
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
-import torch.nn as nn
-from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
-from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
+from coati.experience_buffer.utils import split_experience_batch
+from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
-from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
-from coati.trainer.strategies.sampler import DistributedSampler
-from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
-from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
+from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
- '''
+ """
Args:
        detached_trainer_name_list: list of names used to get Ray actor handles
        strategy:
        kl_coef: the coefficient of the KL divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
- '''
+ """
def __init__(
- self,
- detached_trainer_name_list: List[str],
- strategy_fn: Callable[[], Strategy],
+ self,
+ detached_trainer_name_list: List[str],
+ strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
- model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
- env_info: Dict[str, str] = None,
- sync_models_from_trainers: bool = False,
- buffer_cpu_offload: bool = True,
- kl_coef: float = 0.1,
- callbacks: List[MakerCallback] = [],
- eval_performance: bool = False,
- debug: bool = False,
- update_lora_weights: bool = False,
- **generate_kwargs):
+ model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
+ env_info: Dict[str, str] = None,
+ sync_models_from_trainers: bool = False,
+ buffer_cpu_offload: bool = True,
+ kl_coef: float = 0.1,
+ callbacks: List[MakerCallback] = [],
+ eval_performance: bool = False,
+ debug: bool = False,
+ update_lora_weights: bool = False,
+ **generate_kwargs,
+ ):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
@@ -66,8 +62,9 @@ def __init__(
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
- evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
- reward_model_numel)
+ evaluator = ExperienceMakerPerformanceEvaluator(
+ actor_numel, critic_numel, initial_model_numel, reward_model_numel
+ )
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
@@ -89,9 +86,9 @@ def __init__(
self._target_idx = 0
if self._debug:
- print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
+ print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
- print(f'[maker{get_rank()}] Waiting for INIT')
+ print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self):
while not self._fully_initialized():
@@ -136,7 +133,7 @@ def _inference_step(self, batch) -> None:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
- experience.to_device('cpu')
+ experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
@@ -155,7 +152,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
- for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
+ for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
@@ -163,7 +160,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
batch = next(it)
self._inference_step(batch)
else:
- with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
+ with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
@@ -171,22 +168,24 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
self._on_loop_end()
@ray.method(concurrency_group="model_io")
- def update_experience_maker(self,
- new_actor_state_dict: Dict[str, Any] = None,
- new_actor_lora_config_dict: Dict[str, Any] = None,
- new_critic_state_dict: Dict[str, Any] = None,
- new_critic_lora_config_dict: Dict[str, Any] = None,
- fully_update: bool = False,
- chunk_start: bool = None,
- chunk_end: bool = None):
- '''
- called by trainer
- chunk_start: Set True at the first call. Before sending state_dict calls
- chunk_end: Set True at the last call. After sending state_dict calls.
- fully_update: Set True if you want to sync models when initializing
-
- TODO: load_state_dict integrate with model-sharding strategy
- '''
+ def update_experience_maker(
+ self,
+ new_actor_state_dict: Dict[str, Any] = None,
+ new_actor_lora_config_dict: Dict[str, Any] = None,
+ new_critic_state_dict: Dict[str, Any] = None,
+ new_critic_lora_config_dict: Dict[str, Any] = None,
+ fully_update: bool = False,
+ chunk_start: bool = None,
+ chunk_end: bool = None,
+ ):
+ """
+        Called by the trainer.
+        chunk_start: set True on the first call, before the state_dict sending calls.
+        chunk_end: set True on the last call, after the state_dict sending calls.
+        fully_update: set True if you want to sync full models when initializing.
+
+ TODO: load_state_dict integrate with model-sharding strategy
+ """
_watch_memory = self._debug
if chunk_start:
if self._debug:
@@ -202,18 +201,22 @@ def update_experience_maker(self,
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
- new_actor_state_dict, new_actor_lora_config_dict)
+ new_actor_state_dict, new_actor_lora_config_dict
+ )
self.actor_lora_constructor.load_state_dict_increase(
- self.experience_maker.actor.model, state_dict_increase)
+ self.experience_maker.actor.model, state_dict_increase
+ )
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
- new_critic_state_dict, new_critic_lora_config_dict)
+ new_critic_state_dict, new_critic_lora_config_dict
+ )
self.critic_lora_constructor.load_state_dict_increase(
- self.experience_maker.critic, state_dict_increase)
+ self.experience_maker.critic, state_dict_increase
+ )
# the lock must be released after both actor and critic being updated
if chunk_end:
@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
- if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
- new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs
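
`_set_default_generate_kwargs` above borrows the wrapped Hugging Face model's own generation hooks when the caller supplies none. The same hasattr-guarded default-fill pattern, sketched with a dummy stand-in model:

```python
def set_default_generate_kwargs(generate_kwargs: dict, model) -> dict:
    new_kwargs = {**generate_kwargs}
    # Fill a hook only when the caller did not provide one and the model offers it.
    if "prepare_inputs_fn" not in new_kwargs and hasattr(model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = model.prepare_inputs_for_generation
    if "update_model_kwargs_fn" not in new_kwargs and hasattr(model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = model._update_model_kwargs_for_generation
    return new_kwargs

class DummyModel:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

print(sorted(set_default_generate_kwargs({}, DummyModel())))  # ['prepare_inputs_fn']
```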
diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py
index a98545d4d751..8e9f78700e29 100644
--- a/applications/Chat/coati/ray/lora_constructor.py
+++ b/applications/Chat/coati/ray/lora_constructor.py
@@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
-import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
-from loralib.layers import LoRALayer
@dataclass
@@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor:
- '''
+ """
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()
- '''
+ """
def __init__(self):
self.lora_config_dict = None
@@ -45,10 +43,10 @@ def register_lora_config(self, lora_config_dict: Dict[str, Any]):
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
- '''
- xxx.lora_A, xxx.lora_B -->> xxx.weight
- Warning: the xxx.weight here is the increment actually.
- '''
+ """
+ xxx.lora_A, xxx.lora_B -->> xxx.weight
+ Warning: the xxx.weight here is the increment actually.
+ """
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
@@ -56,24 +54,25 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
- if k.rpartition('.')[-1] == 'lora_A':
+ if k.rpartition(".")[-1] == "lora_A":
lora_A = v
- layer_prefix = k.rpartition('.')[0]
- elif k.rpartition('.')[-1] == 'lora_B':
- assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
+ layer_prefix = k.rpartition(".")[0]
+ elif k.rpartition(".")[-1] == "lora_B":
+ assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
- state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
+ state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
- raise ValueError('unexpected key')
+ raise ValueError("unexpected key")
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w
+
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
@@ -81,21 +80,21 @@ def T(w):
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
- '''
+ """
The final reconstruction step
- '''
+ """
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
- '''
+ """
        If keep_non_lora is True, also return the non-LoRA state_dict.
- '''
+ """
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
- if 'lora_A' in k or 'lora_B' in k:
+ if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
@@ -106,17 +105,19 @@ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
- '''
+ """
        Extract the LoRA config from every LoraLinear module.
        Returns OrderedDict: name -> LoRAConfig
- '''
+ """
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
- lora_config_dict[name] = LoRAConfig(r=child.r,
- lora_alpha=child.lora_alpha,
- lora_dropout=child.lora_dropout,
- fan_in_fan_out=child.fan_in_fan_out)
+ lora_config_dict[name] = LoRAConfig(
+ r=child.r,
+ lora_alpha=child.lora_alpha,
+ lora_dropout=child.lora_dropout,
+ fan_in_fan_out=child.fan_in_fan_out,
+ )
return lora_config_dict
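
For reference, the heart of `reconstruct_increase` is the usual LoRA identity: the weight delta is the product of the two low-rank factors scaled by `lora_alpha / r`. A self-contained sketch with made-up shapes (only `r`, `lora_alpha`, and `fan_in_fan_out` correspond to the `LoRAConfig` fields above):

```python
import torch

r, lora_alpha, fan_in_fan_out = 4, 32, False
lora_A = torch.randn(r, 32)   # (r, in_features)
lora_B = torch.randn(16, r)   # (out_features, r)

def T(w: torch.Tensor) -> torch.Tensor:
    # Transpose when the wrapped layer stores weights as (in, out).
    return w.T if fan_in_fan_out else w

scaling = lora_alpha / r
weight_increase = T(lora_B @ lora_A) * scaling
assert weight_increase.shape == (16, 32)  # same shape as the base weight
```

Transferring `lora_A` and `lora_B` instead of the dense delta is what makes the remote update cheap: 4 x (32 + 16) = 192 values here versus 16 x 32 = 512 for the full increment.
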
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 761186b95ee5..b88140c0e036 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -1,6 +1,6 @@
import os
from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
import torch
import torch.distributed as dist
@@ -10,7 +10,7 @@
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
@@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
- if model == 'gpt2':
+ if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'bloom':
+ elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'opt':
+ elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'llama':
+ elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
@@ -40,27 +40,27 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
- if model == 'gpt2':
- critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'bloom':
- critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'opt':
- critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'llama':
- critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
+ if model == "gpt2":
+ critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "bloom":
+ critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "opt":
+ critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "llama":
+ critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
else:
        raise ValueError(f'Unsupported critic model "{model}"')
return critic
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
- if model == 'gpt2':
+ if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
- elif model == 'bloom':
+ elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
- elif model == 'opt':
+ elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
- elif model == 'llama':
+ elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
- if strategy == 'ddp':
+ if strategy == "ddp":
strategy_ = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif strategy == 'colossalai_gemini_cpu':
- strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif strategy == 'colossalai_zero2_cpu':
- strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif strategy == "colossalai_gemini":
+ strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif strategy == "colossalai_gemini_cpu":
+ strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif strategy == "colossalai_zero2_cpu":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
- if model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- elif model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
- elif model == 'opt':
+ if model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ elif model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+ elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif model == 'llama':
+ elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
- os.environ["RANK"] = env_info['rank']
- os.environ["LOCAL_RANK"] = env_info['local_rank']
- os.environ["WORLD_SIZE"] = env_info['world_size']
- os.environ['MASTER_PORT'] = env_info['master_port']
- os.environ['MASTER_ADDR'] = env_info['master_addr']
+ os.environ["RANK"] = env_info["rank"]
+ os.environ["LOCAL_RANK"] = env_info["local_rank"]
+ os.environ["WORLD_SIZE"] = env_info["world_size"]
+ os.environ["MASTER_PORT"] = env_info["master_port"]
+ os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
@@ -116,7 +116,7 @@ def get_model_numel(model: nn.Module) -> int:
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
- # a sender will send data to one or more than one receivers
+ # a sender will send data to one or more receivers
# a receiver only has one sender
for i in range(num_receivers):
if i % num_senders == sender_idx:
@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers
-def state_dict_to(state_dict: Dict[str, Any],
- dtype: torch.dtype = torch.float16,
- device: torch.device = torch.device('cpu')):
- '''
- keep state_dict intact
- '''
+def state_dict_to(
+ state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
+):
+ """
+ keep state_dict intact
+ """
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
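
As a quick sanity check of the sender/receiver mapping: in the branch kept above (`num_senders <= num_receivers`, or idle senders allowed), receivers are assigned round-robin by index. A tiny illustration (the helper name is ours):

```python
def receivers_for(sender_idx: int, num_senders: int, num_receivers: int) -> list:
    # Receiver i belongs to sender (i % num_senders).
    return [i for i in range(num_receivers) if i % num_senders == sender_idx]

# Two senders feeding five receivers: each receiver has exactly one sender.
assert receivers_for(0, 2, 5) == [0, 2, 4]
assert receivers_for(1, 2, 5) == [1, 3]
```
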
diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py
index 86142361f3ff..4be5d27f93b1 100644
--- a/applications/Chat/coati/trainer/__init__.py
+++ b/applications/Chat/coati/trainer/__init__.py
@@ -3,8 +3,4 @@
from .rm import RewardModelTrainer
from .sft import SFTTrainer
-__all__ = [
- 'SLTrainer', 'OnPolicyTrainer',
- 'RewardModelTrainer', 'SFTTrainer',
- 'PPOTrainer'
-]
+__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index 0629c9c00cca..0a41d450d41e 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -7,11 +7,10 @@
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
-from torch.utils.data import DataLoader
from .callbacks import Callback
from .strategies import Strategy
-from .utils import CycledDataLoader, is_rank_0
+from .utils import is_rank_0
class SLTrainer(ABC):
@@ -47,11 +46,11 @@ def _eval(self, epoch):
raise NotImplementedError()
def _before_fit(self):
- self.no_epoch_bar = False
+ raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
- for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
+ for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)
@@ -68,12 +67,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
- def __init__(self,
- strategy: Strategy,
- data_buffer: NaiveExperienceBuffer,
- sample_buffer: bool,
- dataloader_pin_memory: bool,
- callbacks: List[Callback] = []) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ data_buffer: NaiveExperienceBuffer,
+ sample_buffer: bool,
+ dataloader_pin_memory: bool,
+ callbacks: List[Callback] = [],
+ ) -> None:
super().__init__()
self.strategy = strategy
self.data_buffer = data_buffer
@@ -121,9 +122,9 @@ def _on_learn_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_learn_batch_start()
- def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks:
- callback.on_learn_batch_end(metrics, experience)
+ callback.on_learn_batch_end(experience)
@abstractmethod
def _make_experience(self, collect_step: int):
@@ -151,27 +152,26 @@ def _update_phase(self, update_step: int):
self._learn(update_step)
self._on_learn_epoch_end(update_step)
+ def _before_fit(self, *args, **kwargs):
+ raise NotImplementedError()
+
def fit(
self,
- prompt_dataloader: DataLoader,
- pretrain_dataloader: DataLoader,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
+ *args,
+ **kwargs,
):
"""
The main training loop of on-policy rl trainers.
Args:
- prompt_dataloader (DataLoader): the dataloader to use for prompt data
- pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
- self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
- self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
-
+ self._before_fit(*args, **kwargs)
with self._fit_ctx():
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode):
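
The net effect of this refactor: subclass-specific setup (dataloaders, loggers) now travels through `fit(*args, **kwargs)` into `_before_fit`, so the base class no longer hard-codes prompt/pretrain dataloaders. A minimal sketch of the new contract (the class below is illustrative, not the real coati trainer):

```python
class ToyOnPolicyTrainer:
    def _before_fit(self, prompt_data, pretrain_data):
        # Setup that used to live inside fit() is now a subclass hook.
        self.prompt_data = prompt_data
        self.pretrain_data = pretrain_data

    def fit(self, num_episodes: int, *args, **kwargs):
        self._before_fit(*args, **kwargs)  # extra args forwarded untouched
        for _ in range(num_episodes):
            pass  # collect/update phases elided

ToyOnPolicyTrainer().fit(3, prompt_data=[1, 2], pretrain_data=[3, 4])
```
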
diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py
index 9ed0ee6f7640..29c8c4f00a5c 100644
--- a/applications/Chat/coati/trainer/callbacks/__init__.py
+++ b/applications/Chat/coati/trainer/callbacks/__init__.py
@@ -2,4 +2,4 @@
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
-__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
+__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py
index f5616048855b..c6e30f04885c 100644
--- a/applications/Chat/coati/trainer/callbacks/base.py
+++ b/applications/Chat/coati/trainer/callbacks/base.py
@@ -5,7 +5,7 @@
class Callback(ABC):
"""
- Base callback class. It defines the interface for callbacks.
+ Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@@ -35,5 +35,5 @@ def on_learn_epoch_end(self, epoch: int) -> None:
def on_learn_batch_start(self) -> None:
pass
- def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def on_learn_batch_end(self, experience: Experience) -> None:
pass
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
index 9b44dafa7eaa..b286c766c263 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
- return float('inf')
- elif y == float('inf'):
- return float('nan')
+ return float("inf")
+ elif y == float("inf"):
+ return float("nan")
return x / y
@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
-
def __init__(self) -> None:
self.start_time: Optional[float] = None
- self.duration: float = 0.
+ self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -52,7 +51,7 @@ def end(self) -> None:
self.start_time = None
def reset(self) -> None:
- self.duration = 0.
+ self.duration = 0.0
class PerformanceEvaluator(Callback):
@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
- def __init__(self,
- actor_num_params: int,
- critic_num_params: int,
- initial_model_num_params: int,
- reward_model_num_params: int,
- enable_grad_checkpoint: bool = False,
- ignore_episodes: int = 0) -> None:
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ initial_model_num_params: int,
+ reward_model_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_episodes: int = 0,
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -136,7 +137,7 @@ def on_learn_batch_start(self) -> None:
return
self.learn_timer.start()
- def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable:
return
self.learn_timer.end()
@@ -155,8 +156,9 @@ def on_fit_end(self) -> None:
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
- avg_make_experience_throughput = self.make_experience_num_samples * \
- self.world_size / (avg_make_experience_duration + 1e-12)
+ avg_make_experience_throughput = (
+ self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
+ )
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
@@ -171,13 +173,11 @@ def on_fit_end(self) -> None:
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
- f'Performance summary:\n'
- + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
-
- + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
- + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
- + f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
- + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
-
- + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
+ f"Performance summary:\n"
+ + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
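
The summary above reports each throughput as samples x world_size / duration, with a tiny epsilon guarding against zero durations. A hedged numeric illustration (all numbers made up):

```python
world_size = 8
learn_num_samples = 512      # samples per rank
avg_learn_duration = 32.0    # seconds, already all-reduce averaged

avg_learn_throughput = learn_num_samples * world_size / (avg_learn_duration + 1e-12)
print(f"throughput: {avg_learn_throughput:.2f} samples/s")  # 128.00
```
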
diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
index f0d77a191a88..0d70b6c53073 100644
--- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
+++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
@@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
"""
- def __init__(self,
- path: str,
- interval: int,
- strategy: Strategy,
- actor: nn.Module = None,
- critic: nn.Module = None,
- actor_optim: Optimizer = None,
- critic_optim: Optimizer = None) -> None:
+ def __init__(
+ self,
+ path: str,
+ interval: int,
+ strategy: Strategy,
+ actor: nn.Module = None,
+ critic: nn.Module = None,
+ actor_optim: Optimizer = None,
+ critic_optim: Optimizer = None,
+ ) -> None:
super().__init__()
- self.path = os.path.join(path, 'checkpoint')
+ self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
- self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+ self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
- base_path = os.path.join(self.path, f'episode_{episode}')
+ base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
-
# save model
if self.model_dict[model][0] is None:
            # saving only optimizer states is meaningless, so it is skipped
continue
- model_path = os.path.join(base_path, f'{model}.pt')
+ model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
@@ -71,5 +72,5 @@ def on_episode_end(self, episode: int) -> None:
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
- optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+ optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
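
For orientation, the callback above lays checkpoints out as `<path>/checkpoint/episode_<e>/<model>.pt` plus per-rank optimizer files. A small sketch of the resulting paths (directory names come from the code above; the base path is made up):

```python
import os

def checkpoint_paths(base: str, episode: int, model: str, rank: int):
    root = os.path.join(base, "checkpoint", f"episode_{episode}")
    return (
        os.path.join(root, f"{model}.pt"),
        os.path.join(root, f"{model}-optim-rank-{rank}.pt"),
    )

print(checkpoint_paths("runs", 9, "actor", 0))
# ('runs/checkpoint/episode_9/actor.pt', 'runs/checkpoint/episode_9/actor-optim-rank-0.pt')
```
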
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index ef625a1c1b3d..d6966689885e 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -1,34 +1,33 @@
-from typing import Dict, List
+from typing import Dict, List, Optional
-import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic, get_base_model
+from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
-from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
+from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
-from .utils import is_rank_0, to_device
+from .utils import CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
- unwrapper_model = strategy.unwrap_model(actor)
- hf_model = get_base_model(unwrapper_model)
+ unwrapped_model = strategy.unwrap_model(actor)
+ hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
- if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
- new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
@@ -41,7 +40,7 @@ class PPOTrainer(OnPolicyTrainer):
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
- reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
+        reward_model (RewardModel): the reward model in rlhf algorithm to compute rewards for sequences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of the actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
@@ -60,45 +59,42 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
- def __init__(self,
- strategy: Strategy,
- actor: Actor,
- critic: Critic,
- reward_model: nn.Module,
- initial_model: Actor,
- actor_optim: Optimizer,
- critic_optim: Optimizer,
- kl_coef: float = 0.1,
- ptx_coef: float = 0.9,
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- buffer_cpu_offload: bool = True,
- eps_clip: float = 0.2,
- vf_coef: float = 1.0,
- value_clip: float = 0.4,
- sample_buffer: bool = False,
- dataloader_pin_memory: bool = True,
- offload_inference_models: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs
- ) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ actor: Actor,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: Actor,
+ actor_optim: Optimizer,
+ critic_optim: Optimizer,
+ tokenizer: PreTrainedTokenizerBase,
+ kl_coef: float = 0.1,
+ ptx_coef: float = 0.9,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ buffer_cpu_offload: bool = True,
+ eps_clip: float = 0.2,
+ vf_coef: float = 1.0,
+ value_clip: float = 0.4,
+ sample_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ offload_inference_models: bool = True,
+ callbacks: List[Callback] = [],
+ **generate_kwargs,
+ ) -> None:
if isinstance(strategy, GeminiStrategy):
- assert not offload_inference_models, \
- "GeminiPlugin is not compatible with manual model.to('cpu')"
+ assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
- super().__init__(
- strategy, data_buffer,
- sample_buffer, dataloader_pin_memory,
- callbacks
- )
+ super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
- self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
- self.offload_inference_models = offload_inference_models
+ self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.actor = actor
self.critic = critic
+ self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
@@ -108,84 +104,99 @@ def __init__(self,
self.actor_optim = actor_optim
self.critic_optim = critic_optim
+ self.offload_inference_models = offload_inference_models
self.device = get_current_device()
+ def _before_fit(
+ self,
+ prompt_dataloader: DataLoader,
+ pretrain_dataloader: DataLoader,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ prompt_dataloader (DataLoader): the dataloader to use for prompt data
+ pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+ """
+ self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
+ self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
+
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-ppo", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "ppo")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
            # TODO(ver217): this could be handled by the strategy if the models are prepared by the strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
- if isinstance(prompts, Tensor):
- return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
- elif isinstance(prompts, dict):
- return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
- else:
- raise ValueError(f'Unsupported input type "{type(prompts)}"')
+ assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
+ return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
- def _training_step(self, experience: Experience) -> Dict[str, float]:
+ def _training_step(self, experience: Experience):
self.actor.train()
self.critic.train()
# policy loss
- num_actions = experience.action_mask.size(1)
- actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
- action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ num_actions = experience.action_log_probs.size(1)
+ actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
+ action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
+ actor_loss = (1 - self.ptx_coef) * actor_loss
+ self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
- ptx_log_probs = self.actor(batch['input_ids'],
- attention_mask=batch['attention_mask'])['logits']
- ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
- actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
+ ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
+ ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
+ self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
- self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
+ critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
- return {'reward': experience.reward.mean().item()}
-
def _learn(self, update_step: int):
if self.offload_inference_models:
- self.experience_maker.initial_model.to('cpu')
- self.experience_maker.reward_model.to('cpu')
+ self.experience_maker.initial_model.to("cpu")
+ self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
- metrics = self._training_step(experience)
- self._on_learn_batch_end(metrics, experience)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
- pbar = tqdm(
- self.dataloader,
- desc=f'Train epoch [{update_step + 1}]',
- disable=not is_rank_0()
- )
+ pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
- metrics = self._training_step(experience)
- self._on_learn_batch_end(metrics, experience)
- pbar.set_postfix(metrics)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
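
Note the subtle change in `_training_step`: the old code combined `ptx_loss * ptx_coef + actor_loss * (1 - ptx_coef)` into one backward pass, while the new code backwards each scaled term separately and relies on gradient accumulation before the single optimizer step. The two are mathematically equivalent; a toy check with made-up scalar "losses":

```python
import torch

ptx_coef = 0.9
w = torch.tensor(2.0, requires_grad=True)
actor_loss, ptx_loss = w * 3.0, w * 5.0

# Two separate backwards, as in the new _training_step ...
((1 - ptx_coef) * actor_loss).backward()
(ptx_coef * ptx_loss).backward()

# ... accumulate to the gradient of the old combined loss.
expected = (1 - ptx_coef) * 3.0 + ptx_coef * 5.0
assert torch.isclose(w.grad, torch.tensor(expected))
```
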
diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py
index 54a5d0f40dea..d7f8c21a5a3d 100644
--- a/applications/Chat/coati/trainer/rm.py
+++ b/applications/Chat/coati/trainer/rm.py
@@ -1,7 +1,5 @@
-from datetime import datetime
-from typing import Callable
+from typing import Callable, Optional
-import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer
@@ -40,10 +38,12 @@ def __init__(
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
+ self.num_train_step = 0
+
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
- dist, on, cnt = 0, 0, 0
+ dist, num_correct, num_samples = 0, 0, 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
@@ -52,30 +52,21 @@ def _eval(self, epoch):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
- for i in range(len(chosen_reward)):
- cnt += 1
- if chosen_reward[i] > reject_reward[i]:
- on += 1
+ num_samples += chosen_ids.size(0)
+ num_correct += (chosen_reward > reject_reward).sum().item()
dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader)
- self.acc = on / cnt
+ self.acc = num_correct / num_samples
- if is_rank_0():
- log = pd.DataFrame(
- [[(epoch + 1) * len(self.train_dataloader),
- self.loss.item(), self.dist, self.acc]],
- columns=['step', 'loss', 'dist', 'acc']
- )
- log.to_csv('log.csv', mode='a', header=False, index=False)
+ if self.writer:
+ self.writer.add_scalar("eval/dist", self.dist, epoch)
+ self.writer.add_scalar("eval/acc", self.acc, epoch)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
- len(self.train_dataloader),
- desc='Train step of epoch %d' % epoch,
- disable=not is_rank_0()
+ len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
- cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
@@ -83,29 +74,50 @@ def _train(self, epoch):
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
- self.loss = self.loss_fn(chosen_reward, reject_reward)
- self.strategy.backward(self.loss, self.model, self.optimizer)
+ loss = self.loss_fn(chosen_reward, reject_reward)
+ self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
- cnt += 1
- if cnt % 100 == 0:
+ if self.writer:
+ self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
+ self.writer.add_scalar(
+ "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
+ )
+ self.num_train_step += 1
+ if self.num_train_step % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
- def _before_fit(self,
- train_dataloader: DataLoader,
- valid_dataloader: DataLoader,
- eval_dataloader: DataLoader):
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: DataLoader,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
- valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
- super()._before_fit()
- self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-
self.train_dataloader = train_dataloader
- self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
+
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-rm", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "rm")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
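
The rewritten `_eval` replaces the per-sample Python loop with a vectorized comparison, which is both faster and clearer. A minimal check with made-up rewards:

```python
import torch

chosen_reward = torch.tensor([1.0, 0.2, 0.9])
reject_reward = torch.tensor([0.5, 0.4, 0.1])

num_correct = (chosen_reward > reject_reward).sum().item()   # 2
num_samples = chosen_reward.size(0)                          # 3
assert num_correct / num_samples == 2 / 3
```
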
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index e4d0a970740d..7d0eeec897e5 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -1,10 +1,8 @@
-import time
from typing import Optional
import torch
import torch.distributed as dist
import tqdm
-import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@@ -39,48 +37,43 @@ def __init__(
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1:
- assert not isinstance(strategy, GeminiStrategy), \
- "Accumulation steps are not supported in stage 3 of ColossalAI"
+ assert not isinstance(
+ strategy, GeminiStrategy
+ ), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
+ self.num_train_step = 0
+ self.num_eval_step = 0
+
def _train(self, epoch: int):
self.model.train()
- for batch_id, batch in enumerate(self.train_dataloader):
-
+ step_bar = tqdm.trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
- if "attention_mask" in batch:
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
- else:
- outputs = self.model(batch["input_ids"],
- labels=batch["labels"])
-
- loss = outputs.loss
- loss = loss / self.accumulation_steps
-
- self.strategy.backward(loss, self.model, self.optimizer)
-
+ outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ loss = outputs.loss / self.accumulation_steps
self.total_loss += loss.item()
-
+ self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation
- if (batch_id + 1) % self.accumulation_steps == 0:
+ if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
- if is_rank_0() and self.use_wandb:
- wandb.log({
- "loss": self.total_loss / self.accumulation_steps,
- "lr": self.scheduler.get_last_lr()[0],
- "epoch": epoch,
- "batch_id": batch_id
- })
+ if self.writer:
+ self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
+ self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+ self.num_train_step += 1
self.total_loss = 0
- self.step_bar.update()
+ step_bar.update()
+ step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:
@@ -89,23 +82,26 @@ def _eval(self, epoch: int):
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
- loss = outputs.loss
-
- loss_sum += loss.item()
+ outputs = self.model(
+ batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
+ )
+ loss_sum += outputs.loss.item()
num_seen += batch["input_ids"].size(0)
-
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
- self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
+ self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
+ if self.writer:
+ self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
+ self.num_eval_step += 1
- def _before_fit(self,
- train_dataloader: DataLoader,
- eval_dataloader: Optional[DataLoader] = None,
- logger: Optional[DistributedLogger] = None,
- use_wandb: bool = False):
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: Optional[DataLoader] = None,
+ logger: Optional[DistributedLogger] = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
"""
Args:
train_dataloader: the dataloader to use for training
@@ -115,15 +111,20 @@ def _before_fit(self,
self.eval_dataloader = eval_dataloader
self.logger = logger
- self.use_wandb = use_wandb
- if use_wandb:
- wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
- wandb.watch(self.model)
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-sft", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "sft")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0
- self.no_epoch_bar = True
- self.step_bar = tqdm.trange(
- len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
- desc=f'steps',
- disable=not is_rank_0()
- )
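
All three trainers (PPO, RM, SFT) now share the same logging setup: TensorBoard is the primary sink, and wandb (when enabled) mirrors it via `sync_tensorboard=True`. A distilled sketch of that shared pattern (the helper name `make_writer` is ours):

```python
import os
import time
from typing import Optional

from torch.utils.tensorboard import SummaryWriter

def make_writer(log_dir: Optional[str], stage: str, use_wandb: bool, rank_0: bool) -> Optional[SummaryWriter]:
    if not rank_0 or log_dir is None:
        return None
    if use_wandb:
        import wandb
        # wandb simply mirrors whatever is written to TensorBoard.
        wandb.init(project=f"Coati-{stage}", sync_tensorboard=True)
    run_dir = os.path.join(log_dir, stage, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
    return SummaryWriter(log_dir=run_dir)
```
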
diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py
index b49a2c742db3..521dcb5855b1 100644
--- a/applications/Chat/coati/trainer/strategies/__init__.py
+++ b/applications/Chat/coati/trainer/strategies/__init__.py
@@ -2,7 +2,4 @@
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
-__all__ = [
- 'Strategy', 'DDPStrategy',
- 'LowLevelZeroStrategy', 'GeminiStrategy'
-]
+__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index c20b2b16e396..a78716216ae0 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -19,7 +19,7 @@
class Strategy(ABC):
"""
- Base class for training strategies.
+ Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
@@ -83,16 +83,18 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
- boost_result = dict(model=model,
- optimizer=optimizer,
- criterion=criterion,
- dataloader=dataloader,
- lr_scheduler=lr_scheduler)
+ boost_result = dict(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler,
+ )
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
- raise RuntimeError(f'Type {type(arg)} is not supported')
+ raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets
@@ -108,8 +110,8 @@ def unwrap_model(model: nn.Module) -> nn.Module:
"""
return model
- def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
- self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
+ def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
+ self.booster.save_model(model, path, shard=shard, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
@@ -125,11 +127,9 @@ def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, 1, 0)
@abstractmethod
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+ def save_pretrained(
+ self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
pass
@abstractmethod
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index fa55f97ad661..7129edb060ef 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -1,17 +1,12 @@
import warnings
from typing import Optional
-import torch
-import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
-from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
@@ -42,37 +37,34 @@ class LowLevelZeroStrategy(DDPStrategy):
"""
- def __init__(self,
- stage: int = 2,
- precision: str = 'fp16',
- seed: int = 42,
- placement_policy: str = 'cuda',
- reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
- overlap_communication: bool = True, # only for stage 1&2
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- max_norm: float = 0.0,
- norm_type: float = 2.0
- ) -> None:
-
+ def __init__(
+ self,
+ stage: int = 2,
+ precision: str = "fp16",
+ seed: int = 42,
+ placement_policy: str = "cuda",
+ reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
+ overlap_communication: bool = True, # only for stage 1&2
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
- assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
- assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
+ assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
+ assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
- # zero_config
stage=stage,
precision=precision,
- # zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
- cpu_offload=(placement_policy == 'cpu'),
- # optim_config
+ cpu_offload=(placement_policy == "cpu"),
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
@@ -81,14 +73,15 @@ def __init__(self,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
- norm_type=norm_type
+ norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
- assert isinstance(self.plugin, LowLevelZeroPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(
+ self.plugin, LowLevelZeroPlugin
+ ), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@@ -131,54 +124,55 @@ class GeminiStrategy(DDPStrategy):
"""
- def __init__(self,
- seed: int = 42,
- shard_init: bool = False, # only for stage 3
- placement_policy: str = 'cuda',
- pin_memory: bool = True, # only for stage 3
- force_outputs_fp32: bool = False, # only for stage 3
- search_range_m: int = 32, # only for stage 3
- hidden_dim: Optional[int] = None, # only for stage 3
- min_chunk_size_m: float = 32, # only for stage 3
- gpu_margin_mem_ratio: float = 0.0, # only for stage 3
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- max_norm: float = 0.0,
- norm_type: float = 2.0
- ) -> None:
-
- assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
-
+ def __init__(
+ self,
+ seed: int = 42,
+ shard_init: bool = False, # only for stage 3
+ placement_policy: str = "auto",
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ pin_memory: bool = True, # only for stage 3
+ force_outputs_fp32: bool = False, # only for stage 3
+ search_range_m: int = 32, # only for stage 3
+ hidden_dim: Optional[int] = None, # only for stage 3
+ min_chunk_size_m: float = 32, # only for stage 3
+ gpu_margin_mem_ratio: float = 0.0, # only for stage 3
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
- f'Shard init is not supported model.from_pretrained() yet. '
- 'Please load weights after strategy.prepare()'
+                "Shard init is not supported with model.from_pretrained() yet. "
+ "Please load weights after strategy.prepare()"
)
self.shard_init = shard_init
- warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+        warnings.warn("Stage 3 only supports fp16. Precision is set to fp16.")
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
- # gemini_config
- device=get_current_device(),
+ chunk_init_device=get_current_device(),
placement_policy=placement_policy,
- precision='fp16',
+ shard_param_frac=shard_param_frac,
+ offload_optim_frac=offload_optim_frac,
+ offload_param_frac=offload_param_frac,
+ precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
- # zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
- # optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
@@ -187,29 +181,20 @@ def __init__(self,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
- norm_type=norm_type
+ norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
- assert isinstance(self.plugin, GeminiPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
- world_size = dist.get_world_size()
- shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
- default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
- return ColoInitContext(device=get_current_device(),
- dtype=torch.half,
- default_pg=shard_pg,
- default_dist_spec=default_dist_spec)
+ return super().model_init_context()
def unwrap_model(self, model: nn.Module) -> nn.Module:
- assert isinstance(model, GeminiModel)
- ddp_model = model.unwrap()
- assert isinstance(ddp_model, GeminiDDP)
- return ddp_model.module
+ assert isinstance(model, GeminiDDP)
+ return model.module
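
The strategy change here tracks a GeminiPlugin API update: the single `placement_policy='cpu'/'cuda'` switch is replaced by a `'static'`/`'auto'` policy plus explicit offload fractions (and `device=` becomes `chunk_init_device=`). Roughly, the legacy strings map as follows (our summary, consistent with the `get_strategy_from_args` changes above, not an official table):

```python
# Legacy placement string -> new GeminiPlugin arguments (approximate mapping).
LEGACY_TO_NEW = {
    "cuda": dict(placement_policy="static"),  # keep params/optimizer on GPU
    "cpu": dict(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0),
}
```
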
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index a52b0460daa8..f2a44aeb0961 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):
class DDPStrategy(Strategy):
"""
- Strategy for distributed training using torch.distributed.
+ Strategy for distributed training using torch.distributed.
"""
- def __init__(self,
- seed: int = 42,
- plugin_initializer: Callable = TorchDDPPlugin
- ) -> None:
+ def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None:
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
- dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
+ dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
@@ -60,8 +57,7 @@ def _try_init_dist(self, force: bool = False) -> None:
raise e
def _post_init(self) -> None:
- assert isinstance(self.plugin, TorchDDPPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
@@ -73,12 +69,14 @@ def set_seed(self, seed: int) -> None:
torch.manual_seed(seed)
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
- return self.plugin.prepare_dataloader(data_buffer,
- batch_size=data_buffer.sample_batch_size,
- shuffle=True,
- drop_last=True,
- pin_memory=pin_memory,
- collate_fn=data_buffer.collate_fn)
+ return self.plugin.prepare_dataloader(
+ data_buffer,
+ batch_size=data_buffer.sample_batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=pin_memory,
+ collate_fn=data_buffer.collate_fn,
+ )
def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray; not tested after adapting to the Boost API.
@@ -88,12 +86,10 @@ def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap()
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- if not only_rank0 or dist.get_rank() == 0:
+ def save_pretrained(
+ self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
+ if dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
@@ -102,35 +98,29 @@ def save_pretrained(self,
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
- model_path = os.path.join(path, "pytorch_model.bin")
- self.save_model(model,
- model_path,
- only_rank0=only_rank0)
- def _replace_keys(model_path: str,
- replace_fn: Callable):
+ model_path = os.path.join(path, "pytorch_model.bin")
+ self.save_model(model, model_path, shard=shard)
+ def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
- state_dict = {
- replace_fn(k): v
- for k, v in state_dict.items()
- }
+ state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
-
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
+
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
- if 'requires_grad_only' in config and config['requires_grad_only'] == True:
+        if "requires_grad_only" in config and config["requires_grad_only"]:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
- if 'shard_size' in config:
- shard_size = config['shard_size']
+ if "shard_size" in config:
+ shard_size = config["shard_size"]
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
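
The `_replace_keys` hack above strips the leading `model.` prefix that `save_model` adds to checkpoint keys. A quick self-contained illustration with made-up keys:

```python
state_dict = {"model.lm_head.weight": 1, "model.model.embed_tokens.weight": 2}

# replace(..., 1) strips only the first occurrence, so nested prefixes survive.
renamed = {k.replace("model.", "", 1): v for k, v in state_dict.items()}
assert list(renamed) == ["lm_head.weight", "model.embed_tokens.weight"]
```
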
diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py
index d726fa640fa2..6e811bef11a5 100644
--- a/applications/Chat/coati/trainer/strategies/sampler.py
+++ b/applications/Chat/coati/trainer/strategies/sampler.py
@@ -4,7 +4,6 @@
class DistributedSampler:
-
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
@@ -12,7 +11,7 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
- (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@@ -20,10 +19,10 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
- indices = indices[:self.total_size]
+ indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 7e2cb9c634f7..7811e7365eeb 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -42,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
-
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md
deleted file mode 100644
index 68b03be16a30..000000000000
--- a/applications/Chat/evaluate/README.md
+++ /dev/null
@@ -1,396 +0,0 @@
-# Evaluation
-
-In this directory, we introduce how you can evaluate your model with our pipeline. The pipeline now supports evaluating both Chinese and English capability.
-
-## Installation
-
-To start model evaluation, you need to install the required packages listed in `requirements.txt` under the `evaluate` folder.
-
-```shell
-pip install -r requirements.txt
-```
-
-## Evaluation Pipeline
-
-The whole evaluation pipeline consists of three methods:
-
-1. `GPT Evaluation`: evaluates model predictions using GPT models.
-   - Compare the performance of two different models (battle).
-   - Rate a model according to pre-defined metrics using a designed prompt.
-   - Rate a model according to pre-defined metrics, with an additional reference answer, using a designed prompt.
-2. `Automatic Evaluation`: evaluates model predictions using automatic metrics.
-3. `UniEval`: evaluates model predictions using UniEval models (English only).
-
-### Evaluation Category
-
-Our evaluation pipeline examines the model's capability using 10 categories of questions. The following table introduces each category:
-
-| Evaluation Category | Description |
-| :-----------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. |
-| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. |
-| Classification | Models are asked to do classification tasks. The capability of accurate classification is required. |
-| Closed QA | Models are asked to answer a closed QA question. The capability of answering questions with limited scope (such as single/multiple choice questions) is required. |
-| Extraction | Models are asked to extract information from a given material. The capability of extracting required information is required. |
-| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating texts in a high quality and human-written way is required. |
-| Open QA | Models are asked to answer an open QA question (without context provided). The capability of answering questions with the model's own knowledge base is required. |
-| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. |
-| Rewriting | Models are asked to do rewriting tasks such as translation and grammar correction. The capability of rewriting according to different instructions is required. |
-| Summarization | Models are asked to summarize the given paragraph or passage. The capability of summarization is required. |
-
-To better understand each evaluation category, here are some example questions.
-
-| Evaluation Category | Chinese Example | English Example |
-| :-----------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| Brainstorming | **Example 1:**<br>请介绍一下人工智能的多个领域。<br>**Example 2:**<br>请给出管理家庭财务的 3 个小技巧。<br> | **Example 1:**<br>How can I improve my memory? Any useful techniques you can suggest?<br>**Example 2:**<br>What are some ways to increase productivity while working from home? |
-| Chat | **Example 1:**<br>基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。<br>小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。<br>老李:你好,小张,我很乐意帮助你。你想问些什么?<br>小张:我想知道如何确定鸡的品种和性别?<br>老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?<br>小张:<br>**Example 2:**<br>基于以下角色信息完成一段对话。小明是一名医生,一位老年病患者想要停药,但他对病情有所忽视并有担忧;王叔叔是老年病患者的儿子,希望能够听取医生的建议。<br>小明:你好,王叔叔,我了解你想要让你父亲停药。<br>王叔叔:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。<br>小明: | **Example 1:**<br>Complete a conversation based on the following character information. Amy is a 30-year-old chef who runs her own restaurant. Jack is a food blogger who specializes in reviewing local restaurants.<br>Amy: Hi Jack, I heard that you're a food blogger. Nice to meet you.<br>Jack: Hi Amy, yes I am. Your restaurant has been receiving a lot of good reviews lately.<br>Amy: Yes, we use only fresh and quality ingredients, and every dish is carefully crafted.<br>Jack:<br>**Example 2:**<br>Complete a dialogue based on the following role information. A: Elementary student B: Teacher<br>B: Good morning, Student A. Today we're going to learn about addition and subtraction.<br>A: Teacher, I already know this very well. Why do I need to learn it again?<br>B: |
-| Classification | **Example 1:**<br>新闻标题:今日立夏,有一上联,立夏万物并秀,下联怎么对?<br>请根据以上新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。<br>**Example 2:**<br>新闻标题:赵丽颖很久没有登上微博热搜了,但你们别急,她只是在憋大招而已。<br>请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | **Example 1:**<br>Title: Fighting for Love (2020)<br>Description: Jasmine got obsessed with a man and now he's obsessed with her. Steamy nights, kisses and rules being broken awaits them. She turned his whole world upside down and now he's doing it to hers. In this free fall, can they survive each others love?\"<br>Based on the above information, determine which genre the work of art belongs to. You can only choose one from \"sport\", \"horror\", \"drama\", \"history\", \"romance\", \"biography\", \"science fiction\", \"comedy\", \"animation\", \"documentary\", \"music\" and \"news\".<br>**Example 2:**<br>Title: Summer Breeze: The Isley Brothers Greatest Hits Live (2005)<br>Description: Filmed in the US in 2005 and captured in excellent form led by Ron Isley's vocals and Ernie Isley's hard edged guitar. Virtually every track is a hit including Shout, Who's That Lady, Twist And Shout, Summer Breeze and Harvest For The World.<br>Based on the above information, determine which genre the work of art belongs to. You can only choose one from \"sport\", \"horror\", \"drama\", \"history\", \"romance\", \"biography\", \"science fiction\", \"comedy\", \"animation\", \"documentary\", \"music\" and \"news\"." |
-| Closed QA | **Example 1:**<br>请从以下选项中选择正确答案。以下哪个是世界上最高山峰?<br>A. 长城<br>B. 泰山<br>C. 珠穆朗玛峰<br>D. 黄山<br>**Example 2:**<br>请从以下选项中选择一个最佳答案回答下面的问题。问题:非洲最高的山是哪座山?<br>选项:<br>A. 麦金利山<br>B. 喜马拉雅山<br>C. 乞力马扎罗山 | **Example 1:**<br>Which of the following options is NOT a primary color?<br>(a) yellow<br>(b) blue<br>(c) orange<br>(d) red<br>**Example 2:**<br>Choose the correct option to complete the following sentence: \"Harry Potter and the Chamber of Secrets\" is the **\_\_\_\_** book in the Harry Potter series.<br>(A) first<br>(B) second<br>(C) third<br>(D) fourth |
-| Extraction | **Example 1:**<br>根据以下新闻文本,提取新闻报道时间,例如回答时按照格式“新闻报道时间:2007 年 8 月 10 日”<br>新闻文本如下:2007-4-7 中新网 4 月 7 日电据中国消防在线消息,4 月 4 日晚上 7 时 30 分左右,湖南长潭高速公路上发生一起 6 车连环相撞失火事故。长株潭三地消防部门共出动消防车 21 台,警力 100 余人。经过消防官兵近 2 个小时奋力扑救,大火被成功扑灭。据初步调查,有 1 人在此次事故中死亡。<br>**Example 2:**<br>根据以下新闻文本,提取新闻报道时间,例如回答时按照格式“新闻报道时间:2007 年 8 月 10 日”<br>新闻文本如下:2014 年 1 月 15 日,据外媒《俄罗斯报》报道称,位于北半球的澳大利亚现在正处于炎热的夏季,而近日也到了高温酷暑的时候,当地时间 1 月 14 日晚,澳大利亚南部一夜间发生至少 250 起火灾。受炎热天气及雷雨天气影响,澳大利亚南部一夜间发生至少 250 起火灾,灾情多集中在维多利亚州。火灾发生后,救援人员立即展开救灾行动。目前,大部分起火点火势已被控制。 | **Example 1:**<br>Ernest Hemingway, an American literary giant known for his spare and direct writing style, has penned timeless works such as 'The Old Man and the Sea', 'For Whom the Bell Tolls', and 'A Farewell to Arms', which have made a profound impact on the literary world and continue to be widely read and admired today.<br>Extract the name of the author mentioned above.<br>**Example 2:**<br>In the epic fantasy series 'A Song of Ice and Fire', George R.R. Martin weaves a complex web of political intrigue, war, and magic across the fictional continents of Westeros and Essos. Martin's richly developed characters and intricate plotlines have captivated readers worldwide, much like his other acclaimed works such as 'A Clash of Kings' and 'A Storm of Swords'.<br>Extract the name of the author in the above material. |
-| Generation | **Example 1:**<br>请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。<br>**Example 2:**<br>请根据以下情节撰写一篇短篇小说:一名年轻人被困在一个荒岛上,他必须想办法生存下去直到被救援。但他很快发现自己并不孤单。 | **Example 1:**<br>Write a descriptive paragraph about an island to relax and unwind, including details about the location and atmosphere.<br>**Example 2:**<br>Can you help me write a persuasive email to my colleagues encouraging them to participate in a charitable fundraising event? |
-| Open QA | **Example 1:**<br>请问万有引力定律由谁提出的?<br>**Example 2:**<br>哪些国家参与了第一次世界大战? | **Example 1:**<br>What are the four basic tastes of the human palate?<br>**Example 2:**<br>Who painted The Scream? |
-| Rewriting | **Example 1:**<br>请将以下句子改为正确的语序。<br>生日快乐你祝他了吗?<br>**Example 2:**<br>将以下文本翻译成英语:<br>“这个周末我要去海边玩” | **Example 1:**<br>Please translate the following sentences, which are a mixture of Chinese and English, into full English.<br>我需要买一些 healthy snacks,比如 nuts 和 dried fruits,作为我的 office 的午餐.<br>**Example 2:**<br>Please rewrite the sentence using an inverted sentence structure.<br>We won't begin our journey until the sun sets. |
-| Roleplay | **Example 1:**<br>我想让你担任 Android 开发工程师面试官。我将成为候选人,您将向我询问 Android 开发工程师职位的面试问题。我希望你只作为面试官回答。不要一次写出所有的问题。我希望你只对我进行采访。问我问题,等待我的回答。不要写解释。像面试官一样一个一个问我,等我回答。我的第一句话是“面试官你好”。<br>**Example 2:**<br>我想让你扮演讲故事的角色。你会想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的有潜力的故事以吸引人们的注意力和想象力。根据目标受众,您可以为您的讲故事环节选择特定的主题或主题,例如,如果是儿童,那么您可以谈论动物;如果是成人,那么基于历史的故事可能会更好地吸引他们等。我的第一个请求是我需要一个关于毅力的有趣故事。 | **Example 1:**<br>Assume the role of a marriage counselor. Develop a series of communication exercises for a couple who are experiencing difficulties in their relationship. These exercises should promote active listening, empathy, and effective expression of emotions. Your first assignment is to provide a set of three exercises that focus on resolving conflicts and rebuilding trust.<br>**Example 2:**<br>I want you to act as a travel agent. I will tell you my desired destination, travel dates, and budget, and it will be your job to suggest the best travel itinerary for me. Your recommendations should include the best transportation options, hotel accommodations, and any popular tourist attractions nearby. My first request is "I want to plan a trip to Tokyo for a week, with a budget of $2000. I want to explore the culture and food of the city." |
-| Summarization | **Example 1:**<br>请简要总结概括以下段落材料。<br>当地时间 29 日,泰国卫生部通报,新增 143 名新冠肺炎确诊病例和 1 名死亡病例。截止到当地时间 29 日上午,泰国累计确诊病例 1388 例,其中泰国籍 1172 例,非泰国籍 216 例。死亡病例累计 7 例。(原题为《泰国新增 143 例新冠肺炎确诊病例累计确诊 1388 例》)<br>**Example 2:**<br>请简要总结概括以下段落材料。<br>近期,参与京雄高铁站站房建设的中铁十二局,因在施工过程中存在环境违法行为被雄安新区公开通报。通报发出后,引起社会广泛关注。近日,人民网记者从雄安新区相关部门及中铁十二局获悉,新区有关部门已经集中约谈了中铁十二局等 24 个参与雄安建设的项目单位。对于约谈内容和结果,中铁十二局有关宣传负责人回应:“具体内容不清楚,最好找雄安新区相关部门了解情况。”新区有关部门负责人表示,此前涉及的环境违法行为,中铁十二局已基本整改到位,但约谈内容和结果暂不公开,接下来,将按部就班推进环境治理工作。(原题为《雄安新区:中铁十二局涉环境违法已基本整改到位》) | **Example 1:**<br>The 21 year-old-woman was treated by paramedics after the kitchen fire in Botfield Road in Shifnal, Shropshire. West Mercia Police said it is treating Wednesday morning's incident as arson and are appealing for any witnesses to contact them. The 50-year-old man has been arrested on suspicion of arson with intent to endanger life. For more on this and other stories from Shropshire.<br>Please briefly summarize the above material within 20 words.<br>**Example 2:**<br>South Wales Police were called to a property in Heolgerrig, Merthyr Tydfil, at about 13:40 BST on Sunday. The child was airlifted to Prince Charles Hospital but died shortly afterwards. Police are investigating the circumstances surrounding the incident and have appealed for witnesses. The girl's family are being supported by specially trained officers.<br>Please briefly summarize the above material within 20 words. |
-
-### Evaluation Metrics
-
-#### GPT Evaluation
-
-GPT evaluation uses GPT models to evaluate the predictions of different models, applying different pre-defined evaluation metrics to different categories. The following table shows the 11 pre-defined evaluation metrics in both Chinese and English:
-
-| Evaluation Metric | Prompt Words | CoT (Chain-of-Thought) |
-| :----------------------------------: | :------------------------------------------------------------ | :------------------------------------------------------------ |
-| 语言组织<br>(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。<br>2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。<br>3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。<br>4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。<br>5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。<br>6. 根据以上因素综合评估答案的语言组织,并给出一个 1 到 5 的分数,其中 5 表示语言组织非常好,而 1 表示语言组织非常差。<br>1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.<br>2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.<br>3. Determine if the answer is relevant to the question or topic and conveys a clear message.<br>4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.<br>5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.<br>6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. |
-| 切题<br>(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。<br>2. 阅读答案,确认答案是否直接回答了题目所问的问题。<br>3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。<br>4. 根据以上因素综合评估答案的切题程度,并给出一个 1 到 5 的分数,其中 5 表示答案非常切题,而 1 表示答案完全没有切题。<br>1. Read the question to determine what the question asks and what aspects of the question need to be answered.<br>2. Read the answers to make sure that they directly answer the question asked.<br>3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.<br>4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. |
-| 创意性<br>(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br>2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。<br>3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。<br>4. 根据答案的创意性,给出一个 1 到 5 的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。<br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.<br>3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.<br>4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. |
-| 实用性<br>(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br>2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。<br>3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。<br>4. 根据答案的实用性,给出一个 1 到 5 的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。<br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.<br>3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.<br>4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. |
-| 正确性<br>(Correctness) | 正确性(1-5):答案是否正确。Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。<br>2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为 5 分。如果答案是部分正确的,则可以给予适当的得分,例如 2 分、3 分或 4 分。如果答案完全不正确,则只得 1 分。<br>1. Read the question carefully and try to answer the question yourself.<br>2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. |
-| 自然<br>(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。<br>2. 检查答案内容是否符合题目给定的身份。<br>3. 根据以上因素,对该回答的自然性进行打分,分数从 1 到 5,其中 1 表示不自然,5 表示非常自然,并符合问题给定的身份。<br>1. Read the question and determine the identity information provided in the question.<br>2. Check whether the content of the answer matches the identity given in the question.<br>3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. |
-| 参与感<br>(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。<br>2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。<br>3. 根据以上因素,对该回答的参与感进行打分,分数从 1 到 5,其中 1 表示没有参与感,5 表示非常有参与感,并且恰当地理解了对话的语境和背景。<br>1. Read the questions to determine the context and background of the dialogue.<br>2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.<br>3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. |
-| 合理性<br>(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。<br>2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。<br>3. 根据以上因素,对该回答的合理性进行打分,分数从 1 到 5,其中 1 表示不合理,5 表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。<br>1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.<br>2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.<br>3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. |
-| 多样性<br>(Diversity) | 多样性(1-5):答案使用语言是否优美,具有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。<br>2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。<br>3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。<br>4. 检查回答的合理性和适度,看看回答是否夸张或离题。<br>5. 将多样性的评分打分在 1 到 5 之间,5 分表示回答的质量很好,能够吸引人阅读,1 分表示回答的内容生硬或者有离题的问题。<br>1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.<br>2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.<br>3. Check the creativity and imagination of the response to see if the response is engaging to read on.<br>4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.<br>5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. |
-| 保真度<br>(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。<br>2. 阅读题目的请求,确认回答请求时需要注意的细节。<br>3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。<br>4. 结合以上评估结果给出保真度的评分,范围从 1 到 5 分,其中 1 分表示回答与角色设定完全不符,5 分表示回答完全符合角色设定且满足给定请求。<br>1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.<br>2. Read the question's request and confirm the details that need to be taken into account when answering the request.<br>3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.<br>4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. |
-| 简明扼要<br>(Conciseness) | 简明扼要(1-5):答案是否简明扼要,没有冗余内容。Conciseness (1-5): answers should be concise and without redundant content. | 1. 阅读题目,提取出材料的重点。<br>2. 阅读该总结,并注意其中的主要观点和信息。<br>3. 评估总结的长度。一个简明扼要的总结通常应该在几句话或几段文字内传达关键信息,而不是冗长的段落或文章。<br>4. 检查总结是否包含与主要观点无关的信息或冗余信息。<br>5. 确定总结涵盖了材料中的关键信息,并且没有忽略任何重要细节。<br>6. 给总结打出 1-5 的分数,其中 5 表示总结简明扼要,没有冗余内容,而 1 表示总结冗长或包含不必要的信息,难以理解或记忆。根据您的判断,打出适当的得分。<br>1. Read the title and extract the main points of the material.<br>2. Read the summary and note the main ideas and messages in it.<br>3. Assess the length of the summary. A concise summary should usually convey key information within a few sentences or paragraphs, rather than lengthy paragraphs or essays.<br>4. Check that the summary does not contain information that is not relevant to the main ideas or that is redundant.<br>5. Make sure that the summary covers the key information in the material and that no important details have been omitted.<br>6. Rate the summary on a scale of 1-5, where 5 means the summary is concise and free of redundancy, and 1 means the summary is lengthy or contains unnecessary information that is difficult to understand or remember. Based on your judgment, assign the appropriate score. |
-
-GPT models evaluate the quality of model predictions based on the given prompt words and give a score from 1 to 5.
-
-> **NOTE 1:** Even for the same metric, the details of its prompt words and CoT (Chain-of-Thought) can differ based on which category you want to evaluate. For example, the prompt words for the metric `correctness` shown here are "Whether the answer is correct or not." (this is for the category `classification`), but for the category `extraction` the prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT (Chain-of-Thought) in `prompt/evaluation_prompt`.
-
-> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
-
-#### Automatic Evaluation
-
-Automatic metrics evaluate the capability of a model by comparing model predictions with reference answers.
-There are two ways to obtain reference answers:
-
-- For instructions derived from human-designed problems (e.g., roleplay, chat), the reference answers are generated by GPT-3.5.
-- For instructions related to classic NLP problems (e.g., classification, extraction, summarization), the reference answers are collected from open-source datasets with target answers.
-
-There are 6 types of automatic evaluation metrics listed in the table below:
-
-| Automatic Evaluation Metric | Description |
-| :---------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| BLEU-n | Measures the accuracy between prediction and reference.<br>BLEU-1 (unigram) evaluates accuracy at the word level.<br>BLEU-n (n-gram) evaluates fluency at the sentence level. |
-| ROUGE | ROUGE-N measures the number of matching n-grams between prediction and reference.<br>ROUGE-L measures the longest common subsequence (LCS) between prediction and reference. |
-| Distinct | Measures the diversity of generated text by counting the unique n-grams. |
-| BERTScore | Measures the semantic similarity between tokens of predictions and references with BERT. |
-| Precision<br>Recall<br>F1 Score | Measure the overlap between prediction and reference (designed for the classification and extraction categories). |
-| CHRF | Measures the similarity of character n-grams between prediction and reference. |
-
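As a concrete reference for the one similarity-free metric above, a Distinct-n score is the ratio of unique n-grams to total n-grams over all predictions. The pipeline's real implementations live in `metrics.py`, so the following minimal sketch is illustrative only:

```python
def distinct_n(predictions, n=2):
    # Illustrative sketch: ratio of unique n-grams to total n-grams.
    # Assumes whitespace tokenization; the pipeline may tokenize differently.
    total, unique = 0, set()
    for text in predictions:
        tokens = text.split()
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0
```

A higher score indicates more diverse generations, which is why `Distinct` is recommended further below for open-ended categories such as `brainstorming`.
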
-#### UniEval Evaluation
-
-UniEval converts all evaluation tasks of different dimensions (metrics) into Boolean QA problems and utilizes the model to answer with “Yes” or “No”. Compared with similarity-based metrics such as ROUGE and BLEU, UniEval can achieve a more comprehensive evaluation. In addition, UniEval also demonstrates its ability to transfer to unseen dimensions and tasks.
-
-In our evaluation pipeline, two pre-trained UniEval evaluators are used. One is [unieval-sum](https://huggingface.co/MingZhong/unieval-sum) and the other is [unieval-dialog](https://huggingface.co/MingZhong/unieval-dialog). The two models cover three tasks, `summarization`, `dialogue` and `data2text`. Each task has different evaluation dimensions.
-
-| UniEval Model | Task | Dimension(Metric) |
-| :------------: | :------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| unieval-sum | summarization | coherence: whether the summary is coherent<br>consistency: whether the claim is consistent with the given document<br>fluency: whether the paragraph is fluent<br>relevance: whether the summary is relevant to the reference |
-| unieval-sum | data2text | naturalness: whether the utterance is fluent<br>informativeness: whether the utterance is informative according to the reference |
-| unieval-dialog | dialogue | naturalness: whether the response is natural in the dialogue<br>coherence: whether the response is coherent in the dialogue history<br>understandability: whether the response is understandable in the dialogue |
-
-> **NOTE 1:** Task "data2text" uses the same model as task "summarization".
-
-> **NOTE 2:** In the UniEval paper, the `unieval-sum` model demonstrates the best transfer ability, so you can evaluate your customized metric with this model. Details of adding customized metrics can be found in [FAQ](#faq).
-
-> **NOTE 3:** We do not include all of UniEval's metrics in our pipeline because the data structure and content of the instructions we evaluate are not suitable for direct use with some of them.
-
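To make the Boolean QA formulation concrete, the sketch below scores one dimension by comparing the probabilities of "Yes" and "No" as the first decoded token of a T5-style UniEval evaluator. The `question: ... utterance: ...` template mirrors the format used in `unieval/utils.py` (see the FAQ); the exact prompt wording and scoring details of the shipped evaluators may differ:

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("MingZhong/unieval-dialog")
model = T5ForConditionalGeneration.from_pretrained("MingZhong/unieval-dialog")

# One Boolean QA input for the "naturalness" dimension of the dialogue task.
text = "question: Is this a natural response in the dialogue? utterance: Sure, see you at 6pm!"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    out = model(**inputs,
                decoder_input_ids=torch.tensor([[model.config.decoder_start_token_id]]))

# Take P("Yes") over {"Yes", "No"} as the dimension score
# (assumes both answers tokenize to a single leading token).
logits = out.logits[0, -1]
yes_id = tokenizer("Yes", add_special_tokens=False).input_ids[0]
no_id = tokenizer("No", add_special_tokens=False).input_ids[0]
score = torch.softmax(logits[[yes_id, no_id]], dim=0)[0].item()
```
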
-## Evaluation Process
-
-### Data Format
-
-#### Target Answers / Predictions
-
-A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question.
-An element should have the following fields:
-
-- `category` (str, compulsory): The category of the instruction / question.
-- `instruction` (str, compulsory): The instruction / question for the LLM.
-- `input` (str, optional): The additional context of the instruction / question.
-- `output` (str, optional): The sample output of the instruction (default: GPT-3.5).
-- `target` (str, optional): The target answer for the instruction.
-- `id` (int, compulsory): The ID of the instruction / question.
-
-If the instruction has a target answer, the `output` can be empty. Otherwise, we generate answers from GPT-3.5 as the `output`, and the `target` field is left empty.
-
-Example:
-
-```json
-[
- {
- "category": "brainstorming",
- "instruction": "请介绍一下人工智能的多个领域。",
- "input": "",
- "output": "{GPT-3.5 Answers}",
- "target": "",
- "id": 1
- },
- {
- "category": "classification",
- "instruction": "新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。",
- "input": "",
- "output": "",
- "target": "{target answer}",
- "id": 2
- }
-]
-```
-
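Because `category`, `instruction`, and `id` are compulsory while either `target` or `output` must carry the reference, a quick sanity check along the following lines can catch malformed files before running the pipeline; this helper is illustrative and not part of the pipeline:

```python
import json

REQUIRED_FIELDS = {"category", "instruction", "id"}


def validate_target_file(path):
    # Illustrative sanity check for the target-answer format described above.
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    assert isinstance(records, list), "the file must contain one JSON list"
    for record in records:
        missing = REQUIRED_FIELDS - record.keys()
        assert not missing, f"record {record.get('id')} is missing {missing}"
        # Each record needs either a collected target or a GPT-3.5 generated output.
        assert record.get("target") or record.get("output"), \
            f"record {record.get('id')} has neither a target nor an output"
```
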
-#### Model Answers / Predictions
-
-A JSON file contains one list. Each element in the list is a model answer / prediction record for one instruction / question.
-
-An element should have the following fields:
-
-- `category` (str, compulsory): The category of the instruction / question.
-- `instruction` (str, compulsory): The instruction / question for the LLM.
-- `input` (str, optional): The additional context of the instruction / question.
-- `output` (str, compulsory): The output from the LLM.
-- `target` (str, optional): The target answer for the instruction.
-- `id` (int, compulsory): The ID of the instruction / question.
-
-Example:
-
-```json
-[
- {
- "category": "brainstorming",
- "instruction": "请介绍一下人工智能的多个领域。",
- "input": "",
- "output": "{Model Answers / Predictions}",
- "target": "",
- "id": 1
- },
- {
- "category": "classification",
- "instruction": "新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。",
- "input": "",
- "output": "{Model Answers / Predictions}",
- "target": "{target answer}",
- "id": 2
- }
-]
-```
-
-### Prompt
-
-#### Battle Prompt
-
-The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `prompt/battle_prompt`.
-
-```json
-{
- "id": 1,
- "system_prompt": "你是一个检查回答质量的好助手。",
- "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n",
- "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。"
-}
-```
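
The placeholders in `prompt_template` are presumably filled per question with ordinary string formatting before the request is sent to the reviewer model; the helper below is illustrative, with field names taken from the JSON above:

```python
def build_battle_input(battle_prompt, question, answer_1, answer_2):
    # Fill the battle template with one question and the two models' answers.
    return battle_prompt["prompt_template"].format(
        question=question,
        answer_1=answer_1,
        answer_2=answer_2,
        prompt=battle_prompt["prompt"],  # the scoring instructions shown above
    )
```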
-
-#### Evaluation Prompt
-
-The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide the CoT (Chain-of-Thought) in `CoT`. You can find example evaluation prompt files for Chinese and English in `prompt/evaluation_prompt`.
-
-```json
-{
- "brainstorming": {
- "id": 1,
- "category": "brainstorming",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:"
- },
- "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- }
-}
-```
-
-`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file.
-
-`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`.
-
-### Evaluation
-
-#### Configuration
-
-The following is an example of an English config file. The configuration file controls how the pipeline evaluates the model. You need to specify the GPT evaluation metrics, automatic metrics, and UniEval metrics (English only) under the keys `GPT`, `Metrics`, and `UniEval`. You can find example Chinese and English config files in `config`.
-
-```json
-{
- "language": "en",
- "path_for_UniEval": {
- "summarization": "path to unieval-sum model",
- "dialogue": "path to unieval-dialog model",
- "data2text": "path to unieval-sum model"
- },
- "category": {
- "brainstorming": {
- "GPT": ["relevance", "creativity", "practicality", "reasonableness"],
- "Metrics": ["Distinct"],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "chat": {
- "GPT": ["relevance", "naturalness", "engagingness", "reasonableness"],
- "Metrics": ["Distinct"],
- "UniEval": [
- "dialogue-naturalness",
- "dialogue-coherence",
- "dialogue-understandability"
- ]
- }
- }
-}
-```
-
-`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.
-
-`"path_for_UniEval"`: path to the UniEval model.
-
-`"category"`: the category/categories needed to evaluate the model capability.
-
-`"GPT"`: the metrics you want to use for GPT evaluation.
-
-`"Metrics"`: the metrics you want to use for automatic metrics evaluation.
-
-`"UniEval"`: the metrics you want to use for UniEval metrics evaluation. The metric has to be in the `"{task}-{metric}"` format because different tasks have same metrics such as naturalness and coherence.
-
-You can remove a key such as `"Metrics"` to skip evaluating answers with the corresponding evaluation metrics.
-
-You can create your config file based on the available settings listed in the following table.
-
-| "category" | "GPT" | "Metrics" | "UniEval" |
-| :--------------: | :---------------------: | :---------: | :--------------------------: |
-| "brainstorming" | "language organization" | "BLEU" | "dialogue-naturalness" |
-| "chat" | "relevance" | "ROUGE" | "dialogue-coherence" |
-| "classification" | "creativity" | "Distinct" | "dialogue-understandability" |
-| "closed_qa" | "practicality" | "BERTScore" | "data2text-naturalness" |
-| "extraction" | "correctness" | "Precision" | "data2text-informativeness" |
-| "generation" | "naturalness" | "Recall" | "summarization-coherence" |
-| "open_qa" | "engagingness" | "F1 score" | "summarization-consistency" |
-| "rewriting" | "reasonableness" | "CHRF" | "summarization-fluency" |
-| "roleplay" | "diversity" | | "summarization-relevance" |
-| "summarization" | "fidelity" | | |
-| | "conciseness" | | |
-
-> **NOTE:** For categories without standard answers, such as `brainstorming`, you should avoid similarity-based automatic metrics such as `BLEU` and `ROUGE`, and use `Distinct` instead in your config file.
-
-#### Evaluate
-
-After setting the configuration file, you can evaluate the model using `eval.py`. If you want to compare the answers of two different models, specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1, and the program will perform evaluation using automatic metrics and GPT models.
-
-An example script is provided as follows:
-
-```shell
-python eval.py \
- --config_file "path to the config file" \
- --battle_prompt_file "path to the prompt file for battle" \
- --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
- --target_file "path to the target answer file" \
- --answer_file_list "path to the answer files of at most 2 models" \
- --model_name_list "the names of at most 2 models" \
- --gpt_model "which GPT model to use for evaluation" \
- --save_path "path to save results" \
- --openai_key "your openai key" \
-```
-
-If you want GPT evaluation with a reference answer, add the argument `--gpt_with_reference`.
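
For the two-model battle mode specifically, an invocation might look like the following sketch; battle mode uses a GPT model (typically GPT-4) as the reviewer and does not need `--target_file` or `--gpt_evaluation_prompt_file`:

```shell
python eval.py \
    --config_file "path to the config file" \
    --battle_prompt_file "path to the prompt file for battle" \
    --answer_file_list "answer file of model 1" "answer file of model 2" \
    --model_name_list "name of model 1" "name of model 2" \
    --gpt_model gpt-4 \
    --save_path "path to save results" \
    --openai_key "your openai key"
```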
-
-## FAQ
-
-How can I add a new GPT evaluation metric?
-
-For example, if you want to add a new metric `persuasiveness` to the category `brainstorming`, you should add the metric definition and its corresponding CoT (Chain-of-Thought) to the evaluation prompt file in `prompt/evaluation_prompt`. The CoT can be generated using ChatGPT: you can prompt ChatGPT to generate evaluation steps for the new metric.
-
-```json
-{
- "brainstorming": {
- "id": 1,
- "category": "brainstorming",
- "metrics": {
- "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness"
- },
- "CoT": {
- "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
- },
- "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- }
-}
-```
-
-
-
-How can I add a new UniEval evaluation metric?
-
-For example, if you want to add a new metric `persuasiveness` to the task `data2text`, you should add a Boolean QA question about the metric in the function `add_question` in `unieval/utils.py`. Note that how effectively the model evaluates this metric is unknown; you may need some experiments to test whether the model is capable of evaluating it.
-
-```python
-if task == 'data2text':
- if dimension == 'persuasiveness':
-        cur_input = 'question: Is this a persuasive utterance utterance: ' + output[i]
-```
-
-
-
-## To Do
-
-- [x] Add evaluation for English capability
-- [x] Support UniEval
-- [x] Support GPT-4 evaluation
-- [x] Support GPT evaluation with reference
-
-## Citations
-
-```bibtex
-@misc{vicuna2023,
- title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality},
- url = {https://vicuna.lmsys.org},
- author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.},
- month = {March},
- year = {2023}
-}
-
-@misc{liu2023geval,
- title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment},
- author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu},
- year={2023},
- eprint={2303.16634},
- archivePrefix={arXiv},
- primaryClass={cs.CL}
-}
-
-@misc{zhong2022unified,
- title={Towards a Unified Multi-Dimensional Evaluator for Text Generation},
- author={Ming Zhong and Yang Liu and Da Yin and Yuning Mao and Yizhu Jiao and Pengfei Liu and Chenguang Zhu and Heng Ji and Jiawei Han},
- year={2022},
- eprint={2210.07197},
- archivePrefix={arXiv},
- primaryClass={cs.CL}
-}
-```
diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json
deleted file mode 100644
index 023f16bef31c..000000000000
--- a/applications/Chat/evaluate/config/config_cn.json
+++ /dev/null
@@ -1,204 +0,0 @@
-{
- "language": "cn",
- "category": {
- "brainstorming": {
- "GPT": [
- "language organization",
- "relevance",
- "creativity",
- "practicality",
- "reasonableness"
- ],
- "Metrics": [
- "Distinct"
- ]
- },
- "chat": {
- "GPT": [
- "language organization",
- "naturalness",
- "engagingness",
- "fidelity"
- ],
- "Metrics": [
- "Distinct"
- ]
- },
- "classification": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Precision",
- "Recall",
- "F1 score",
- "CHRF"
- ]
- },
- "closed_qa": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore",
- "CHRF"
- ]
- },
- "extraction": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Precision",
- "Recall",
- "F1 score",
- "CHRF"
- ]
- },
- "generation": {
- "GPT": [
- "language organization",
- "relevance",
- "diversity"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore"
- ]
- },
- "logical_reasoning": {
- "GPT": [
- "correctness",
- "relevance",
- "reasonableness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore",
- "CHRF"
- ]
- },
- "open_qa": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Distinct"
- ]
- },
- "rewriting": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore"
- ]
- },
- "roleplay": {
- "GPT": [
- "language organization",
- "relevance",
- "fidelity",
- "creativity"
- ],
- "Metrics": [
- "Distinct"
- ]
- },
- "summarization": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness",
- "conciseness"
- ],
- "Metrics": [
- ]
- },
- "Finance": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "Law": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "Education": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "Medical": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "STEM": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "SocialScience": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "Humanity": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "Other": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- },
- "ethics": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ]
- }
- }
-}
diff --git a/applications/Chat/evaluate/config/config_en.json b/applications/Chat/evaluate/config/config_en.json
deleted file mode 100644
index c964122dd6d6..000000000000
--- a/applications/Chat/evaluate/config/config_en.json
+++ /dev/null
@@ -1,283 +0,0 @@
-{
- "language": "en",
- "path_for_UniEval": {
- "summarization": "path to unieval-sum",
- "dialogue": "path to unieval-dialog",
- "data2text": "path to unieval-sum"
- },
- "category": {
- "brainstorming": {
- "GPT": [
- "language organization",
- "relevance",
- "creativity",
- "practicality",
- "reasonableness"
- ],
- "Metrics": [
- "Distinct"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "chat": {
- "GPT": [
- "language organization",
- "naturalness",
- "engagingness",
- "fidelity"
- ],
- "Metrics": [
- "Distinct"
- ],
- "UniEval": [
- "summarization-fluency",
- "dialogue-naturalness",
- "dialogue-coherence",
- "dialogue-understandability",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "classification": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Precision",
- "Recall",
- "F1 score",
- "CHRF"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "closed_qa": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore",
- "CHRF"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "extraction": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Precision",
- "Recall",
- "F1 score",
- "CHRF"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "generation": {
- "GPT": [
- "language organization",
- "relevance",
- "diversity"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "logical_reasoning": {
- "GPT": [
- "correctness",
- "relevance",
- "reasonableness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore",
- "CHRF"
- ],
- "UniEval": [
- ]
- },
- "open_qa": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "Distinct"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "rewriting": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "roleplay": {
- "GPT": [
- "language organization",
- "relevance",
- "fidelity",
- "creativity"
- ],
- "Metrics": [
- "Distinct"
- ],
- "UniEval": [
- "summarization-fluency",
- "data2text-naturalness",
- "data2text-informativeness"
- ]
- },
- "summarization": {
- "GPT": [
- "language organization",
- "relevance",
- "correctness",
- "conciseness"
- ],
- "Metrics": [
- "BLEU",
- "ROUGE",
- "BERTScore",
- "CHRF"
- ],
- "UniEval": [
- ]
- },
- "Finance": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "Law": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "Education": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "Medical": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "STEM": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "SocialScience": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "Humanity": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "Other": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- },
- "ethics": {
- "GPT": [
- "relevance",
- "correctness"
- ],
- "Metrics": [
- ],
- "UniEval": [
- ]
- }
- }
-}
diff --git a/applications/Chat/evaluate/eval.py b/applications/Chat/evaluate/eval.py
deleted file mode 100644
index e3fe0e9e091b..000000000000
--- a/applications/Chat/evaluate/eval.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import argparse
-import json
-import os
-
-import openai
-from evaluator import Evaluator
-from utils import jload
-
-
-def main(args):
- assert len(args.answer_file_list) == len(
- args.model_name_list), "The number of answer files and model names should be equal!"
-
- # load config
- config = jload(args.config_file)
-
- if config["language"] in ["cn", "en"]:
- # get metric settings for all categories
- metrics_per_category = {}
- for category in config["category"].keys():
- metrics_all = {}
- for metric_type, metrics in config["category"][category].items():
- metrics_all[metric_type] = metrics
- metrics_per_category[category] = metrics_all
-
- battle_prompt = None
- if args.battle_prompt_file:
- battle_prompt = jload(args.battle_prompt_file)
-
- gpt_evaluation_prompt = None
- if args.gpt_evaluation_prompt_file:
- gpt_evaluation_prompt = jload(args.gpt_evaluation_prompt_file)
-
- if len(args.model_name_list) == 2 and not battle_prompt:
- raise Exception("No prompt file for battle provided. Please specify the prompt file for battle!")
-
- if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
- raise Exception(
- "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")
-
- if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
- raise Exception(
- "GPT evaluation with reference is not supported for text-davinci-003. You should specify chat models such as gpt-3.5-turbo or gpt-4."
- )
-
- # initialize evaluator
- evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
- config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference)
- if len(args.model_name_list) == 2:
- answers1 = jload(args.answer_file_list[0])
- answers2 = jload(args.answer_file_list[1])
-
- assert len(answers1) == len(answers2), "The number of answers for two models should be equal!"
-
- evaluator.battle(answers1=answers1, answers2=answers2)
- evaluator.save(args.save_path, args.model_name_list)
- elif len(args.model_name_list) == 1:
- targets = jload(args.target_file)
- answers = jload(args.answer_file_list[0])
-
- assert len(targets) == len(answers), "The number of target answers and model answers should be equal!"
-
- evaluator.evaluate(answers=answers, targets=targets)
- evaluator.save(args.save_path, args.model_name_list)
- else:
- raise ValueError("Unsupported number of answer files and model names!")
- else:
- raise ValueError(f'Unsupported language {config["language"]}!')
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
- parser.add_argument('--config_file',
- type=str,
- default=None,
- required=True,
- help='path to the file of target results')
- parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
- parser.add_argument('--gpt_evaluation_prompt_file',
- type=str,
- default=None,
- help='path to the prompt file for gpt evaluation')
- parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
- parser.add_argument('--answer_file_list',
- type=str,
- nargs='+',
- default=[],
- required=True,
- help='path to the answer files of at most 2 models')
- parser.add_argument('--model_name_list',
- type=str,
- nargs='+',
- default=[],
- required=True,
- help='the names of at most 2 models')
- parser.add_argument('--gpt_model',
- default="gpt-3.5-turbo",
- choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
- help='which GPT model to use for evaluation')
- parser.add_argument('--gpt_with_reference',
- default=False,
- action="store_true",
- help='whether to include reference answer in gpt evaluation')
- parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
- parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key')
- args = parser.parse_args()
-
- if args.openai_key is not None:
- os.environ["OPENAI_API_KEY"] = args.openai_key
- openai.api_key = os.getenv("OPENAI_API_KEY")
-
- main(args)
diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py
deleted file mode 100644
index 3dd5fd6f2f23..000000000000
--- a/applications/Chat/evaluate/evaluator.py
+++ /dev/null
@@ -1,219 +0,0 @@
-import os
-from typing import Any, Dict, List
-
-import gpt_evaluate
-import metrics
-import pandas as pd
-import unieval
-from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
-
-
-class Evaluator(object):
- """
- A class named Evaluator includes GPT-3.5/GPT-4 evaluation
- and automatic evaluation
-
- """
-
- def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
- gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None:
- self.params = params
- self.battle_prompt = battle_prompt
- self.gpt_evaluation_prompt = gpt_evaluation_prompt
- self.gpt_model = gpt_model
- self.language = language
- self.path_for_UniEval = path_for_UniEval
- self.gpt_with_reference = gpt_with_reference
- self.automatic_metric_stats = dict()
- self.unieval_metric_stats = dict()
- self.gpt_evaluation_results = dict()
- self.battle_results = []
-
- def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None:
- """
- Comparison between two models using GPT-4 as the reviewer.
- """
-
- self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt)
-
- def evaluate(self, answers: List[Dict], targets: List[Dict]) -> None:
- """
- A comprehensive evaluation of the answers from the model.
- The function evaluates the model's performance from different perspectives
- using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
-
- The metrics will be decided by the config file.
-
- """
-
- def switch(metric, language):
- if metric == "BLEU":
- return metrics.bleu_score(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "ROUGE":
- return metrics.rouge_score(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "Distinct":
- return metrics.distinct_score(preds=predicts_list, language=language)
- elif metric == "BERTScore":
- return metrics.bert_score(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "Precision":
- return metrics.precision(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "Recall":
- return metrics.recall(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "F1 score":
- return metrics.F1_score(preds=predicts_list, targets=targets_list, language=language)
- elif metric == "CHRF":
- return metrics.chrf_score(preds=predicts_list, targets=targets_list, language=language)
- else:
-                raise ValueError(f"Unexpected metric {metric}!")
-
- answers_per_category = get_data_per_category(answers, list(self.params.keys()))
- targets_per_category = get_data_per_category(targets, list(self.params.keys()))
-
- # automatic evaluation
- for category in self.params:
- if len(answers_per_category[category]) == 0:
- print(f"Category {category} specified in your config doesn't have corresponding answers!")
- continue
-
- if self.params[category].get("Metrics", None) is None:
- continue
-
- category_metrics = self.params[category]["Metrics"]
- self.automatic_metric_stats[category] = {}
-
- targets_list = [
- target["target"] if target["target"] else target["output"] for target in targets_per_category[category]
- ]
- predicts_list = [answer["output"] for answer in answers_per_category[category]]
-
- for metric in category_metrics:
- self.automatic_metric_stats[category].update(switch(metric=metric, language=self.language))
-
- # UniEval evaluation
- # self.unieval_metric_stats's key is "task" instead of "category".
-        # Iterating over "task" first avoids repeatedly loading models, because one task corresponds to one UniEval model.
-        # If the key were "category", models could be loaded multiple times across categories, since the user may require different tasks (models) to evaluate one category.
- for category in self.params:
- if len(answers_per_category[category]) == 0:
- print(f"Category {category} specified in your config doesn't have corresponding answers!")
- continue
-
- if self.params[category].get("UniEval", None) is None:
- continue
-
- if self.params[category]["UniEval"] and self.language == "cn":
-                raise Exception(
-                    "UniEval doesn't support Chinese! Please remove the UniEval config from your Chinese config file.")
-
- category_metrics = self.params[category]["UniEval"]
-
- for task, metric in [tuple(category_metric.split("-")) for category_metric in category_metrics]:
- if self.unieval_metric_stats.get(task, None) is None:
- self.unieval_metric_stats[task] = {category: {metric: 0}}
- elif self.unieval_metric_stats[task].get(category, None) is None:
- self.unieval_metric_stats[task][category] = {metric: 0}
- else:
- self.unieval_metric_stats[task][category][metric] = 0
-
- for task in self.unieval_metric_stats:
- if self.path_for_UniEval is None:
-                raise Exception("Please specify the path for the UniEval model in the config file!")
-
- if self.path_for_UniEval.get(task, None) is None:
- raise Exception(f"Please specify the model path for task {task} in the config file!")
-
- print(f"Load UniEval model for task {task}.")
-
- uni_evaluator = unieval.get_evaluator(task, model_name_or_path=self.path_for_UniEval[task])
- for category in self.unieval_metric_stats[task]:
- targets_list = [
- target["target"] if target["target"] else target["output"]
- for target in targets_per_category[category]
- ]
- predicts_list = [answer["output"] for answer in answers_per_category[category]]
- sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
-
- data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
- scores = uni_evaluator.evaluate(data,
- category,
- dims=list(self.unieval_metric_stats[task][category].keys()),
- overall=False)
- avg_scores = unieval.calculate_average_score(scores)
-
- self.unieval_metric_stats[task][category].update(avg_scores)
-
- # gpt evaluation
- for category in self.params:
- if len(answers_per_category[category]) == 0:
- print(f"Category {category} specified in your config doesn't have corresponding answers!")
- continue
-
- if self.params[category].get("GPT", None) is None:
- continue
-
- category_metrics = self.params[category]["GPT"]
-
- prompt = self.gpt_evaluation_prompt.get(category, None)
- if prompt is None:
-                print(f"No prompt for category {category}! Using the prompt for category general instead.")
- prompt = self.gpt_evaluation_prompt["general"]
-
- self.gpt_evaluation_results[category] = gpt_evaluate.evaluate(
- answers_per_category[category],
- prompt,
- category_metrics,
- category,
- self.gpt_model,
- self.language,
- references=targets_per_category[category] if self.gpt_with_reference else None)
-
- def save(self, path: str, model_name_list: List[str]) -> None:
- """
- Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
-
- """
-
- if len(model_name_list) == 2:
- save_path = os.path.join(path, "gpt_evaluate", "battle_results")
- gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
- else:
- if self.automatic_metric_stats:
- # Save evaluation results for automatic metrics
- automatic_base_save_path = os.path.join(path, "automatic_results")
- automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results")
-
- save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path)
-
- # Save charts and csv.
- automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses")
- analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path)
-
- if self.unieval_metric_stats:
- # Save evaluation results for UniEval metrics
- unieval_base_save_path = os.path.join(path, "unieval_results")
- unieval_results_save_path = os.path.join(unieval_base_save_path, "evaluation_results")
-
- unieval.save_unieval_results(model_name_list[0], self.unieval_metric_stats, unieval_results_save_path)
-
- # Save charts and csv.
- unieval_analyses_save_path = os.path.join(unieval_base_save_path, "evaluation_analyses")
- unieval.analyze_unieval_results(unieval_results_save_path, unieval_analyses_save_path)
-
- if self.gpt_evaluation_results:
- # Save evaluation results for GPT evaluation metrics.
- gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
- gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
-
- all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
- self.gpt_evaluation_results,
- gpt_evaluation_results_save_path)
-
- # Start to calculate scores and save statistics.
- gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
- gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
- gpt_evaluation_statistics_save_path)
-
- # Save charts and csv.
- gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
- gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
- gpt_evaluation_analyses_save_path)
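
Before the metric implementations below, it helps to see the record shape `Evaluator.evaluate` expects. The sketch reconstructs it from the attribute accesses in the deleted code (`category`, `instruction`, `input`, `output`, `target`); treat it as an inferred schema with made-up values, not a documented one.

```python
# Per-sample record shapes inferred from Evaluator: answers contribute
# "instruction" + "input" (concatenated as UniEval sources) and "output"
# (predictions); targets fall back from "target" to "output" when "target"
# is empty. All values here are illustrative.
answer_record = {
    "category": "general",
    "instruction": "Summarize the following passage.",
    "input": "Colossal-AI provides reusable parallel training components ...",
    "output": "It provides reusable parallel training components.",
}

target_record = {
    "category": "general",
    "target": "Colossal-AI offers reusable parallel training components.",
    "output": "",  # consulted only if "target" is empty
}
```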
diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py
deleted file mode 100644
index 77f9b6e98044..000000000000
--- a/applications/Chat/evaluate/metrics.py
+++ /dev/null
@@ -1,253 +0,0 @@
-import statistics
-from typing import Dict, List
-
-import jieba
-from bert_score import score
-from nltk.translate.bleu_score import sentence_bleu
-from nltk.translate.chrf_score import sentence_chrf
-from rouge_chinese import Rouge as Rouge_cn
-from rouge_score import rouge_scorer as Rouge_en
-from sklearn.metrics import f1_score, precision_score, recall_score
-from utils import preprocessing_text, remove_redundant_space
-
-
-def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate BLEU Score Metric
-
-    The calculation includes BLEU-1 for unigrams, BLEU-2 for bigrams,
-    BLEU-3 for trigrams and BLEU-4 for 4-grams. Unigrams evaluate
-    accuracy at the word level; the higher-order n-grams evaluate
-    fluency at the sentence level.
- """
- bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
- cumulative_bleu = [0] * 4
- weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
- (1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
-
- for pred, target in zip(preds, targets):
- if language == "cn":
- pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
- target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()]
- elif language == "en":
- pred_list = preprocessing_text(pred).split()
- target_list = [preprocessing_text(target).split()]
-
- bleu = sentence_bleu(target_list, pred_list, weights=weights)
- cumulative_bleu = [a + b for a, b in zip(cumulative_bleu, bleu)]
-
- for i in range(len(cumulative_bleu)):
- bleu_scores[f"bleu{i+1}"] = cumulative_bleu[i] / len(preds)
-
- return bleu_scores
-
-
-def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
-    """Calculate CHRF Score Metric at the sentence level.
- """
- chrf_score = {"chrf": 0}
- cumulative_chrf = []
-
- for pred, target in zip(preds, targets):
- if language == "cn":
- pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
- target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
- elif language == "en":
- pred_list = preprocessing_text(pred).split()
- target_list = preprocessing_text(target).split()
-
- cumulative_chrf.append(sentence_chrf(target_list, pred_list))
-
- chrf_score["chrf"] = statistics.mean(cumulative_chrf)
-
- return chrf_score
-
-
-def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
- """Calculate Chinese ROUGE Score Metric
-
-    The calculation includes ROUGE-1 for unigrams, ROUGE-2 for bigrams
-    and ROUGE-L. ROUGE-N counts the matching n-grams between the preds
-    and targets, while ROUGE-L measures the longest common subsequence
-    (LCS) between the preds and targets.
- """
- rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
- all_preds = []
- all_targets = []
-
- for pred, target in zip(preds, targets):
- pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred))))
- target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target))))
- all_preds.append(pred_list)
- all_targets.append(target_list)
-
- rouge_cn = Rouge_cn()
- rouge_avg = rouge_cn.get_scores(all_preds, all_targets, avg=True)
-
- rouge_scores["rouge1"] = rouge_avg["rouge-1"]["f"]
- rouge_scores["rouge2"] = rouge_avg["rouge-2"]["f"]
- rouge_scores["rougeL"] = rouge_avg["rouge-l"]["f"]
-
- return rouge_scores
-
-
-def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
- """Calculate English ROUGE Score Metric
-
-    The calculation includes ROUGE-1 for unigrams, ROUGE-2 for bigrams
-    and ROUGE-L. ROUGE-N counts the matching n-grams between the preds
-    and targets, while ROUGE-L measures the longest common subsequence
-    (LCS) between the preds and targets.
- """
- rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
- all_preds = []
- all_targets = []
-
- rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
-
- for pred, target in zip(preds, targets):
- score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
- rouge_scores["rouge1"] += score['rouge1'].fmeasure
- rouge_scores["rouge2"] += score['rouge2'].fmeasure
- rouge_scores["rougeL"] += score['rougeL'].fmeasure
-
- rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
- rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
- rouge_scores["rougeL"] = rouge_scores["rougeL"] / len(preds)
-
- return rouge_scores
-
-
-def rouge_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate ROUGE Score Metric"""
- if language == "cn":
- return rouge_cn_score(preds, targets)
- elif language == "en":
- return rouge_en_score(preds, targets)
-
-
-def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
- """Calculate Distinct Score Metric
-
- This metric refers to https://arxiv.org/abs/1510.03055.
-    It evaluates the diversity of the generated text by counting
-    the unique n-grams.
- """
- distinct_score = {"distinct": 0}
- cumulative_distinct = []
-
- for pred in preds:
- if language == "cn":
- pred_seg_list = ' '.join(jieba.cut(pred)).split()
- count_segs = len(pred_seg_list)
- unique_segs = set(pred_seg_list)
- count_unique_chars = len(unique_segs)
- # prevent denominator from being 0
- cumulative_distinct.append(count_unique_chars / (count_segs + 1e-6))
- elif language == "en":
- # calculate distinct 1-gram, 2-gram, 3-gram
- unique_ngram = [set() for _ in range(0, 3)]
- all_ngram_count = [0 for _ in range(0, 3)]
-
- split_pred = preprocessing_text(pred).split()
- for n in range(0, 3):
- for i in range(0, len(split_pred) - n):
- ngram = ' '.join(split_pred[i:i + n + 1])
- unique_ngram[n].add(ngram)
- all_ngram_count[n] += 1
-
- # Sometimes the answer may contain only one word. For 2-gram and 3-gram, the gram count(denominator) may be zero.
- avg_distinct = [len(a) / (b + 1e-6) for a, b in zip(unique_ngram, all_ngram_count)]
-
- cumulative_distinct.append(statistics.mean(avg_distinct))
-
- distinct_score["distinct"] = statistics.mean(cumulative_distinct)
-
- return distinct_score
-
-
-def bert_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate BERTScore Metric
-
- The BERTScore evaluates the semantic similarity between
- tokens of preds and targets with BERT.
- """
- bert_score = {"bert_score": 0}
- pred_list = []
- target_list = []
-
- for pred, target in zip(preds, targets):
- pred_list.append(pred)
- target_list.append(target)
-
- if language == "cn":
- _, _, F = score(pred_list, target_list, lang="zh", verbose=True)
- elif language == "en":
- _, _, F = score(pred_list, target_list, lang="en", verbose=True)
-
- bert_score["bert_score"] = F.mean().item()
-
- return bert_score
-
-
-def calculate_precision_recall_f1(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Precision, Recall and F1-Score Calculation
-
-    Precision, recall and F1-score are computed by counting positional
-    overlaps between the preds and targets. The comparison length is
-    limited by the shorter of the preds and targets.
- """
- precision_recall_f1 = {"precision": 0, "recall": 0, "f1_score": 0}
- precision_scores = []
- recall_scores = []
- f1_scores = []
-
- for pred, target in zip(preds, targets):
- if language == "cn":
- pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()]
- target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()]
- elif language == "en":
- pred_list = [char for char in preprocessing_text(pred).split()]
- target_list = [char for char in preprocessing_text(target).split()]
-
- target_labels = [1] * min(len(target_list), len(pred_list))
- pred_labels = [int(pred_list[i] == target_list[i]) for i in range(0, min(len(target_list), len(pred_list)))]
-
- precision_scores.append(precision_score(target_labels, pred_labels, zero_division=0))
- recall_scores.append(recall_score(target_labels, pred_labels, zero_division=0))
- f1_scores.append(f1_score(target_labels, pred_labels, zero_division=0))
-
- precision_recall_f1["precision"] = statistics.mean(precision_scores)
- precision_recall_f1["recall"] = statistics.mean(recall_scores)
- precision_recall_f1["f1_score"] = statistics.mean(f1_scores)
-
- return precision_recall_f1
-
-
-def precision(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate Precision Metric
-
-    Precision is computed by counting the number of overlaps between the preds and targets.
- """
- precision = {"precision": 0}
- precision["precision"] = calculate_precision_recall_f1(preds, targets, language)["precision"]
- return precision
-
-
-def recall(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate Recall Metric
-
-    Recall is computed by counting the number of overlaps between the preds and targets.
- """
- recall = {"recall": 0}
- recall["recall"] = calculate_precision_recall_f1(preds, targets, language)["recall"]
- return recall
-
-
-def F1_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate F1-score Metric
-
-    The F1-score is computed by counting the number of overlaps between the preds and targets.
- """
- f1 = {"f1_score": 0}
- f1["f1_score"] = calculate_precision_recall_f1(preds, targets, language)["f1_score"]
- return f1
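
One detail of `bleu_score` above is worth isolating: passing four weight tuples to `sentence_bleu` yields BLEU-1 through BLEU-4 in a single call. A standalone sketch with made-up sentences, assuming NLTK 3.7+ (where a list of weight tuples returns one score per tuple, the same convention the deleted code relies on):

```python
# Cumulative BLEU-1..4 from one sentence_bleu call, mirroring the weights
# used in metrics.py. The sentences are chosen so every n-gram order has
# at least one match (avoiding zero-count smoothing warnings).
from nltk.translate.bleu_score import sentence_bleu

weights = [(1.0, 0.0, 0.0, 0.0), (0.5, 0.5, 0.0, 0.0),
           (1 / 3, 1 / 3, 1 / 3, 0.0), (0.25, 0.25, 0.25, 0.25)]
pred = "the cat sat on the mat".split()
refs = ["the cat sat on the red mat".split()]  # list of tokenized references

bleu1, bleu2, bleu3, bleu4 = sentence_bleu(refs, pred, weights=weights)
print(f"BLEU-1..4: {bleu1:.3f} {bleu2:.3f} {bleu3:.3f} {bleu4:.3f}")
```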
diff --git a/applications/Chat/evaluate/requirements.txt b/applications/Chat/evaluate/requirements.txt
deleted file mode 100644
index 27d317ed88cc..000000000000
--- a/applications/Chat/evaluate/requirements.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-jieba
-bert-score
-rouge_chinese
-scikit-learn
-nltk
-openai
-seaborn
-pandas
-matplotlib
-numpy
-zhon
-rouge_score
diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py
deleted file mode 100644
index dad8d6ad09fa..000000000000
--- a/applications/Chat/evaluate/unieval/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from .evaluator import get_evaluator
-from .utils import (
- analyze_unieval_results,
- calculate_average_score,
- convert_data_to_unieval_format,
- save_unieval_results,
-)
-
-__all__ = [
- 'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
- 'analyze_unieval_results'
-]
diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py
deleted file mode 100644
index 56cc6d2f9e41..000000000000
--- a/applications/Chat/evaluate/unieval/evaluator.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# MIT License
-
-# Copyright (c) 2022 Ming Zhong
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import numpy as np
-from nltk import sent_tokenize
-
-from .scorer import UniEvaluator
-from .utils import add_question
-
-
-class SumEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for text summarization """
- self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- self.task = 'summarization'
- self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
-
- def evaluate(self, data, category, dims=None, overall=True):
- """
- Get the scores of all the given dimensions
-
- category: The category to be evaluated.
-
- dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
- four dimensions: coherence, consistency, fluency, relevance.
-
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
- """
- n_data = len(data)
- eval_scores = [{} for _ in range(n_data)]
-
-        if dims is None:
- eval_dims = self.dimensions
- else:
- assert isinstance(dims, list)
- eval_dims = dims
-
- for dim in eval_dims:
- # Calculate average sentence-level scores for 'consistency' and 'fluency'
- if dim == 'consistency' or dim == 'fluency':
- src_list, output_list = [], []
- n_sents = [] # the number of sentences in each generated summary
- for i in range(n_data):
- source = data[i]['source']
- system_outputs = sent_tokenize(data[i]['system_output'])
- n_sents.append(len(system_outputs))
- for j in range(len(system_outputs)):
- src_list.append(source)
- output_list.append(system_outputs[j])
- input_list = add_question(dimension=dim, output=output_list, src=src_list, task=self.task)
- sent_score = self.scorer.score(input_list, self.task, category, dim)
-
- # Get average score for each sample
- start_idx = 0
- score = []
- for cur_n_sent in n_sents:
- # prevent denominator from being 0
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
- start_idx += cur_n_sent
-
- # Calculate summary-level score for 'coherence' and 'relevance'
- elif dim == 'coherence' or dim == 'relevance':
- src_list, output_list, ref_list = [], [], []
- for i in range(n_data):
- src_list.append(data[i]['source'])
- output_list.append(data[i]['system_output'])
- if dim == 'relevance':
- ref_list.append(data[i]['reference'])
- input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
- score = self.scorer.score(input_list, self.task, category, dim)
-
- # Please customize other dimensions here for summarization
- else:
- raise NotImplementedError('The input format for this dimension is still undefined. \
- Please customize it first.')
-
- for i in range(n_data):
- eval_scores[i][dim] = score[i]
-
- # Customize your overall score here.
-        if overall:
- for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
-
- return eval_scores
-
-
-class DialogEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for dialogues """
- self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- self.task = 'dialogue'
- self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
-
- def evaluate(self, data, category, dims=None, overall=True):
- """
- Get the scores of all the given dimensions
-
- category: The category to be evaluated.
-
- dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
- five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
-
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
- """
- n_data = len(data)
- eval_scores = [{} for _ in range(n_data)]
-
-        if dims is None:
- eval_dims = self.dimensions
- else:
- assert isinstance(dims, list)
- eval_dims = dims
-
- for dim in eval_dims:
- # Calculate summation score for 'engagingness'
- if dim == 'engagingness':
- src_list, output_list, context_list = [], [], []
- n_sents = [] # the number of sentences in each generated response
- for i in range(n_data):
- source = data[i]['source']
- context = data[i]['context']
- system_outputs = sent_tokenize(data[i]['system_output'])
- n_sents.append(len(system_outputs))
- for j in range(len(system_outputs)):
- src_list.append(source)
- context_list.append(context)
- output_list.append(system_outputs[j])
- input_list = add_question(dimension=dim,
- output=output_list,
- src=src_list,
- context=context_list,
- task=self.task)
- sent_score = self.scorer.score(input_list, self.task, category, dim)
-
- # Get the summation score for each sample
- start_idx = 0
- score = []
- for cur_n_sent in n_sents:
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
- start_idx += cur_n_sent
-
- # Calculate turn-level score for other dimensions
- elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
- src_list, output_list, context_list = [], [], []
- for i in range(n_data):
- src_list.append(data[i]['source'])
- output_list.append(data[i]['system_output'])
- context_list.append(data[i]['context'])
- input_list = add_question(dimension=dim,
- output=output_list,
- src=src_list,
- context=context_list,
- task=self.task)
- score = self.scorer.score(input_list, self.task, category, dim)
-
- # Please customize other dimensions here for summarization
- else:
- raise NotImplementedError('The input format for this dimension is still undefined. \
- Please customize it first.')
-
- for i in range(n_data):
- eval_scores[i][dim] = score[i]
-
- # Customize your overall score here.
-        if overall:
- for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
-
- return eval_scores
-
-
-class D2tEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for data-to-text """
- self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- self.task = 'data2text'
- self.dimensions = ['naturalness', 'informativeness']
-
- def evaluate(self, data, category, dims=None, overall=True):
- """
- Get the scores of all the given dimensions
-
- category: The category to be evaluated.
-
- dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
- two dimensions: naturalness and informativeness.
-
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
- """
- n_data = len(data)
- eval_scores = [{} for _ in range(n_data)]
-
-        if dims is None:
- eval_dims = self.dimensions
- else:
- assert isinstance(dims, list)
- eval_dims = dims
-
- for dim in eval_dims:
- output_list, ref_list = [], []
- for i in range(n_data):
- output_list.append(data[i]['system_output'])
- ref_list.append(data[i]['reference'])
-
- input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
- score = self.scorer.score(input_list, self.task, category, dim)
-
- for i in range(n_data):
- eval_scores[i][dim] = score[i]
-
- # Customize your overall score here.
-        if overall:
- for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
-
- return eval_scores
-
-
-class FactEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for factual consistency detection """
- self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- self.task = 'fact'
- self.dim = 'consistency'
-
- def evaluate(self, data, category):
- """
- Get the factual consistency score (only 1 dimension for this task)
-
- category: The category to be evaluated.
- """
- n_data = len(data)
- eval_scores = [{} for _ in range(n_data)]
-
- # Calculate average sentence-level scores for factual consistency
- src_list, output_list = [], []
- n_sents = [] # the number of sentences in the claim
- for i in range(n_data):
- source = data[i]['source']
- system_outputs = sent_tokenize(data[i]['system_output'])
- n_sents.append(len(system_outputs))
- for j in range(len(system_outputs)):
- src_list.append(source)
- output_list.append(system_outputs[j])
- input_list = add_question(dimension=self.dim, output=output_list, src=src_list, task=self.task)
- sent_score = self.scorer.score(input_list, self.task, category, self.dim)
-
- # Get average score for each sample
- start_idx = 0
- score = []
- for cur_n_sent in n_sents:
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
- start_idx += cur_n_sent
-
- for i in range(n_data):
- eval_scores[i][self.dim] = score[i]
-
- return eval_scores
-
-
-def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
- assert task in ['summarization', 'dialogue', 'data2text', 'fact']
- if task == 'summarization':
- return SumEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'dialogue':
- return DialogEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'data2text':
- return D2tEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'fact':
- return FactEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- else:
- raise NotImplementedError('Other tasks are not implemented, \
- please customize specific tasks here.')
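
For reference, this is how the evaluators above were driven from `Evaluator` (compare the `unieval.get_evaluator` and `unieval.convert_data_to_unieval_format` calls in evaluator.py). A minimal sketch with placeholder texts, a placeholder category label, and the default public checkpoint; the printed scores are illustrative only.

```python
# Scoring one summary on two dimensions with the deleted UniEval wrappers.
import unieval

data = unieval.convert_data_to_unieval_format(
    ["The report says sales rose."],                       # model outputs
    ["Quarterly report: sales rose 10% year over year."],  # sources
    ["Sales increased 10%, according to the report."],     # references
)
uni_evaluator = unieval.get_evaluator(
    "summarization", model_name_or_path="MingZhong/unieval-sum")
scores = uni_evaluator.evaluate(
    data, "general", dims=["coherence", "fluency"], overall=False)
print(scores)  # e.g. [{'coherence': 0.93, 'fluency': 0.88}]
```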
diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py
deleted file mode 100644
index 2c70bb9f6ded..000000000000
--- a/applications/Chat/evaluate/unieval/scorer.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# MIT License
-
-# Copyright (c) 2022 Ming Zhong
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import torch
-import torch.nn as nn
-from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
-
-
-class UniEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up model """
- self.device = device
- self.max_length = max_length
-
- self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
- self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir)
-
- self.model.eval()
- self.model.to(device)
-
- self.softmax = nn.Softmax(dim=1)
-
- self.pos_id = self.tokenizer("Yes")["input_ids"][0]
- self.neg_id = self.tokenizer("No")["input_ids"][0]
-
- def score(self, inputs, task, category, dim, batch_size=8):
- """
- Get scores for the given samples.
-        final_score = positive_score / (positive_score + negative_score)
- """
-
- # The implementation of "forward" in T5 still requires decoder_input_ids.
- # Therefore, we construct a random one-word target sequence.
- # The content of the target has no effect on the final scores.
- tgts = ["No" for _ in range(len(inputs))]
-
- pos_score_list, neg_score_list = [], []
- for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
- src_list = inputs[i:i + batch_size]
- tgt_list = tgts[i:i + batch_size]
- try:
- with torch.no_grad():
- encoded_src = self.tokenizer(src_list,
- max_length=self.max_length,
- truncation=True,
- padding=True,
- return_tensors='pt')
- encoded_tgt = self.tokenizer(tgt_list,
- max_length=self.max_length,
- truncation=True,
- padding=True,
- return_tensors='pt')
-
- src_tokens = encoded_src['input_ids'].to(self.device)
- src_mask = encoded_src['attention_mask'].to(self.device)
-
- tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
-
- output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
- logits = output.logits.view(-1, self.model.config.vocab_size)
-
- pos_score = self.softmax(logits)[:, self.pos_id] # Yes
- neg_score = self.softmax(logits)[:, self.neg_id] # No
-
- cur_pos_score = [x.item() for x in pos_score]
- cur_neg_score = [x.item() for x in neg_score]
- pos_score_list += cur_pos_score
- neg_score_list += cur_neg_score
-
- except RuntimeError:
- print(f'source: {src_list}')
- print(f'target: {tgt_list}')
-                exit(1)
-
- score_list = []
- for i in range(len(pos_score_list)):
- score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i]))
-
- return score_list
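
The heart of `UniEvaluator.score` is its final loop: the T5 model's softmax probabilities for the tokens "Yes" and "No" are renormalized into a single score in [0, 1]. Isolated as a pure function:

```python
# UniEval's Boolean-QA scoring rule: renormalize the "Yes"/"No" token
# probabilities so the score lands in [0, 1].
def unieval_score(pos_score: float, neg_score: float) -> float:
    return pos_score / (pos_score + neg_score)

assert abs(unieval_score(0.6, 0.2) - 0.75) < 1e-9  # 0.6 / (0.6 + 0.2)
```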
diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py
deleted file mode 100644
index a381e9e590b2..000000000000
--- a/applications/Chat/evaluate/unieval/utils.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# MIT License
-
-# Copyright (c) 2022 Ming Zhong
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import os
-from typing import Dict
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
-import tqdm
-
-
-def add_question(dimension, output, src=None, ref=None, context=None, task=None):
- """
- Add questions to generate input in Bool-QA format for UniEval.
-
- dimension: specific dimension to be evaluated
- src: source input for different NLG tasks. For example, source document for summarization
- and dialogue history for dialogue response generation.
- output: output text generated by the models
- ref: human-annotated groundtruth
- context: the context needed to evaluate several specific dimension. For example,
- additional factual information when evaluating engagingness and groundedness in dialogues.
- """
-
- input_with_question = []
- for i in range(len(output)):
- # For summarization
- if task == 'summarization':
- if dimension == 'fluency':
- cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i]
- elif dimension == 'coherence':
- cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[
- i] + ' document: ' + src[i]
- elif dimension == 'consistency':
- cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
- i] + ' document: ' + src[i]
- elif dimension == 'relevance':
- cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[
- i] + ' reference: ' + ref[i]
- else:
- raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
- # For dialogues
- elif task == 'dialogue':
- if dimension == 'naturalness':
- cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i]
- elif dimension == 'coherence':
- cur_input = 'question: Is this a coherent response given the dialogue history? response: '\
- + output[i] + ' dialogue history: ' + src[i]
- elif dimension == 'engagingness':
- cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\
- + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i]
- elif dimension == 'groundedness':
- cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\
- + output[i] + ' fact: ' + context[i]
- elif dimension == 'understandability':
- cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i]
- else:
- raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
- # For data-to-text
- elif task == 'data2text':
- if dimension == 'naturalness':
- cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i]
- elif dimension == 'informativeness':
- cur_input = 'question: Is this sentence informative according to the reference? sentence: '\
- + output[i] + ' reference: ' + ref[i]
- else:
- raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
- # For factual consistency detection
- elif task == 'fact':
- if dimension == 'consistency':
- cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
- i] + ' document: ' + src[i]
- else:
- raise NotImplementedError('No other dimensions for the factual consistency detection task.')
- # For new customized tasks
- else:
- raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
- input_with_question.append(cur_input)
- return input_with_question
-
-
-def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
- """
-    Convert the data into UniEval's format.
-
- output_list: a list of model output
-
- src_list: source input for different NLG tasks. For example, source document for summarization
- and dialogue history for dialogue response generation
- ref_list: human-annotated groundtruth
- """
- json_data = []
- for i in range(len(output_list)):
- cur = {}
- cur['system_output'] = output_list[i]
- if src_list is not None:
- cur['source'] = src_list[i]
- if ref_list is not None:
- cur['reference'] = ref_list[i]
- cur['context'] = ""
- json_data.append(cur)
- return json_data
-
-
-def calculate_average_score(scores):
- """
- Calculate average scores for different metrics
-
- scores: a list of scores for different metrics for each answer
-
- """
- metrics = {metric: 0 for metric in scores[0]}
-
- for score in scores:
- for metric in score:
- metrics[metric] += score[metric]
-
- for metric in metrics:
- metrics[metric] /= len(scores)
-
- return metrics
-
-
-def save_unieval_results(model_name: str, unieval_metric_stats: Dict[str, Dict], save_path: str) -> None:
- """
- Save UniEval evaluation results of different categories for one model.
-
- """
-
- if not os.path.exists(save_path):
- os.makedirs(save_path)
-
- unieval_metric_stats_per_category = {}
- for task, category_stat in unieval_metric_stats.items():
- for category, metric_stat in category_stat.items():
- if unieval_metric_stats_per_category.get(category, None) is None:
- unieval_metric_stats_per_category[category] = {}
- for metric, score in metric_stat.items():
- unieval_metric_stats_per_category[category][f"{metric}-{task}"] = score
-
- automatic_df = pd.DataFrame(unieval_metric_stats_per_category)
- automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True)
-
-
-def read_unieval_results(results_path: str, file_name: str) -> Dict[str, Dict]:
- """
- Read a csv file and return a dictionary which stores scores per metric.
-
- """
-
- results = pd.read_csv(os.path.join(results_path, file_name), index_col=0)
-
- results_dict = {metric: {} for metric in list(results.index)}
- for i, metric in enumerate(results_dict.keys()):
- for j, category in enumerate(list(results.columns)):
- if pd.isnull(results.iloc[i][j]):
- continue
- results_dict[metric][category] = results.iloc[i][j]
-
- return results_dict
-
-
-def analyze_unieval_results(results_path: str, save_path: str) -> None:
- """
- Analyze and visualize all csv files in the given folder.
-
- """
-
- if not os.path.exists(results_path):
- raise Exception(f'The given directory "{results_path}" doesn\'t exist! No results found!')
-
- all_statistics = {}
-
- for file_name in os.listdir(results_path):
- if file_name.endswith("_results.csv"):
- model_name = file_name.split("_results.csv")[0]
- all_statistics[model_name] = read_unieval_results(results_path, file_name)
-
- if len(list(all_statistics.keys())) == 0:
- raise Exception(f'There are no csv files in the given directory "{results_path}"!')
-
- frame_all = {"model": [], "category": [], "metric": [], "score": []}
- frame_per_metric = {}
- for model_name, model_statistics in all_statistics.items():
- for metric, metric_statistics in model_statistics.items():
- if frame_per_metric.get(metric) is None:
- frame_per_metric[metric] = {"model": [], "category": [], "score": []}
-
- for category, category_score in metric_statistics.items():
- frame_all["model"].append(model_name)
- frame_all["category"].append(category)
- frame_all["metric"].append(metric)
- frame_all["score"].append(category_score)
-
- frame_per_metric[metric]["model"].append(model_name)
- frame_per_metric[metric]["category"].append(category)
- frame_per_metric[metric]["score"].append(category_score)
-
- if not os.path.exists(save_path):
- os.makedirs(save_path)
-
- frame_all = pd.DataFrame(frame_all)
- frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
-
- for metric in tqdm.tqdm(
- frame_per_metric.keys(),
-            desc="UniEval metrics: ",
- total=len(frame_per_metric.keys()),
- ):
- data = pd.DataFrame(frame_per_metric[metric])
-
- sns.set()
- fig = plt.figure(figsize=(16, 10))
-
- fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True)
- fig.set_title(
- f"Comparison between Different Models for Metric {metric.split('-')[0].title()} in Task {metric.split('-')[1].title()}"
- )
- plt.xlabel("Evaluation Category")
- plt.ylabel("Score")
-
- figure = fig.get_figure()
- figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400)
-
- plt.close()
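
To make the Bool-QA format concrete, here is `add_question` applied to a single sample (summarization task, fluency dimension). The example sentence is made up; the prompt template comes straight from the code above.

```python
# add_question builds the yes/no question UniEval scores. For summarization
# fluency, only the model output is needed (no source/reference/context).
from unieval.utils import add_question

prompts = add_question(
    dimension="fluency",
    output=["The cat sat on the mat."],
    task="summarization",
)
print(prompts[0])
# question: Is this a fluent paragraph? paragraph: The cat sat on the mat.
```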
diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py
deleted file mode 100644
index 406e43db99aa..000000000000
--- a/applications/Chat/evaluate/utils.py
+++ /dev/null
@@ -1,207 +0,0 @@
-import io
-import json
-import os
-import re
-import string
-from typing import Dict
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
-import tqdm
-from zhon import hanzi
-
-
-def _make_w_io_base(f, mode: str):
- if not isinstance(f, io.IOBase):
- f_dirname = os.path.dirname(f)
- if f_dirname != "":
- os.makedirs(f_dirname, exist_ok=True)
- f = open(f, mode=mode)
- return f
-
-
-def _make_r_io_base(f, mode: str):
- if not isinstance(f, io.IOBase):
- f = open(f, mode=mode)
- return f
-
-
-def jdump(obj, f, mode="w", indent=4, default=str):
- """Dump a str or dictionary to a file in json format.
- Args:
- obj: An object to be written.
- f: A string path to the location on disk.
- mode: Mode for opening the file.
- indent: Indent for storing json dictionaries.
- default: A function to handle non-serializable entries; defaults to `str`.
- """
- f = _make_w_io_base(f, mode)
- if isinstance(obj, (dict, list)):
- json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
- elif isinstance(obj, str):
- f.write(obj)
- else:
- raise ValueError(f"Unexpected type: {type(obj)}")
- f.close()
-
-
-def jload(f, mode="r"):
- """Load a .json file into a dictionary."""
- f = _make_r_io_base(f, mode)
- jdict = json.load(f)
- f.close()
- return jdict
-
-
-def get_json_list(file_path):
- with open(file_path, 'r') as f:
- json_list = []
- for line in f:
- json_list.append(json.loads(line))
- return json_list
-
-
-def get_data_per_category(data, categories):
- data_per_category = {category: [] for category in categories}
- for item in data:
- category = item["category"]
- if category in categories:
- data_per_category[category].append(item)
-
- return data_per_category
-
-
-def remove_punctuations(text: str) -> str:
- """
- Remove punctuations in the given text.
- It is used in evaluation of automatic metrics.
-
- """
-
- punctuation = string.punctuation + hanzi.punctuation
- punctuation = set([char for char in punctuation])
- punctuation.difference_update(set("!@#$%&()<>?|,.\"'"))
-
- out = []
- for char in text:
- if char in punctuation:
- continue
- else:
- out.append(char)
-
- return "".join(out)
-
-
-def remove_redundant_space(text: str) -> str:
- """
- Remove redundant spaces in the given text.
- It is used in evaluation of automatic metrics.
-
- """
-
- return " ".join(text.split())
-
-
-def preprocessing_text(text: str) -> str:
- """
- Preprocess the given text.
- It is used in evaluation of automatic metrics.
-
- """
-
- return remove_redundant_space(remove_punctuations(text.lower()))
-
-
-def save_automatic_results(model_name: str, automatic_metric_stats: Dict[str, Dict], save_path: str) -> None:
- """
- Save automatic evaluation results of different categories for one model.
-
- """
-
- if not os.path.exists(save_path):
- os.makedirs(save_path)
-
- automatic_df = pd.DataFrame(automatic_metric_stats)
- automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True)
-
-
-def read_automatic_results(results_path: str, file_name: str) -> Dict[str, Dict]:
- """
- Read a csv file and return a dictionary which stores scores per metric.
-
- """
-
- results = pd.read_csv(os.path.join(results_path, file_name), index_col=0)
-
- results_dict = {metric: {} for metric in list(results.index)}
- for i, metric in enumerate(results_dict.keys()):
- for j, category in enumerate(list(results.columns)):
- if pd.isnull(results.iloc[i][j]):
- continue
- results_dict[metric][category] = results.iloc[i][j]
-
- return results_dict
-
-
-def analyze_automatic_results(results_path: str, save_path: str) -> None:
- """
- Analyze and visualize all csv files in the given folder.
-
- """
-
- if not os.path.exists(results_path):
- raise Exception(f'The given directory "{results_path}" doesn\'t exist! No results found!')
-
- all_statistics = {}
-
- for file_name in os.listdir(results_path):
- if file_name.endswith("_results.csv"):
- model_name = file_name.split("_results.csv")[0]
- all_statistics[model_name] = read_automatic_results(results_path, file_name)
-
- if len(list(all_statistics.keys())) == 0:
- raise Exception(f'There are no csv files in the given directory "{results_path}"!')
-
- frame_all = {"model": [], "category": [], "metric": [], "score": []}
- frame_per_metric = {}
- for model_name, model_statistics in all_statistics.items():
- for metric, metric_statistics in model_statistics.items():
- if frame_per_metric.get(metric) is None:
- frame_per_metric[metric] = {"model": [], "category": [], "score": []}
-
- for category, category_score in metric_statistics.items():
- frame_all["model"].append(model_name)
- frame_all["category"].append(category)
- frame_all["metric"].append(metric)
- frame_all["score"].append(category_score)
-
- frame_per_metric[metric]["model"].append(model_name)
- frame_per_metric[metric]["category"].append(category)
- frame_per_metric[metric]["score"].append(category_score)
-
- if not os.path.exists(save_path):
- os.makedirs(save_path)
-
- frame_all = pd.DataFrame(frame_all)
- frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
-
- for metric in tqdm.tqdm(
- frame_per_metric.keys(),
-            desc="automatic metrics: ",
- total=len(frame_per_metric.keys()),
- ):
- data = pd.DataFrame(frame_per_metric[metric])
-
- sns.set()
- fig = plt.figure(figsize=(16, 10))
-
- fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True)
- fig.set_title(f"Comparison between Different Models for Metric {metric.title()}")
- plt.xlabel("Evaluation Category")
- plt.ylabel("Score")
-
- figure = fig.get_figure()
- figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400)
-
- plt.close()
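
The text helpers above compose into `preprocessing_text`: lowercase, drop all punctuation outside a small ASCII whitelist (everything in `hanzi.punctuation` is removed), and collapse repeated whitespace. A quick check, assuming it is run from the evaluate directory these modules lived in:

```python
# End-to-end behavior of preprocessing_text: lowercase, strip punctuation
# outside the whitelist "!@#$%&()<>?|,.\"'", collapse repeated spaces.
from utils import preprocessing_text

print(preprocessing_text("Hello,   WORLD!!  --  测试。"))
# hello, world!! 测试   ("-" and the fullwidth "。" are stripped)
```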
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md
index 8b2edc48cd99..ada3a16296af 100644
--- a/applications/Chat/examples/community/peft/README.md
+++ b/applications/Chat/examples/community/peft/README.md
@@ -20,7 +20,7 @@ pip install .
For SFT training, just call train_peft_sft.py
-Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py.
+Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument for when you have an eval_dataset file. The data file is plain text; please check the format in easy_dataset.py.
For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data files. The models are included in easy_models.py. Currently only BLOOM models are tested, but GPT-2/OPT/LLaMA should technically be supported as well.
diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py
index 2fe293957079..d4b17689e9cb 100644
--- a/applications/Chat/examples/community/peft/easy_dataset.py
+++ b/applications/Chat/examples/community/peft/easy_dataset.py
@@ -3,7 +3,6 @@
from typing import Dict, Sequence
import torch
-from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest",
max_length=max_length,
truncation=True,
- ) for text in strings
+ )
+ for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
+        # split into source and target: the source is everything up to and including "回答:", the target is everything after it
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
- sources.append(line[:sep_index + 3])
- targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+ sources.append(line[: sep_index + 3])
+ targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
@@ -83,15 +82,17 @@ def __str__(self):
class EasyPromptsDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
+ all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
- tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
- truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
+ tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
+ "input_ids"
+ ]
+ .to(torch.cuda.current_device())
+ .squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
@@ -110,7 +111,6 @@ def __str__(self):
class EasyRewardDataset(Dataset):
-
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
@@ -120,44 +120,42 @@ def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None
else:
self.end_token = special_token
print(self.end_token)
- #read all lines in the train_file to a list
+ # read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
- prompt = "提问:" + data['prompt'] + " 回答:"
-
- chosen = prompt + data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = prompt + data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ prompt = "提问:" + data["prompt"] + " 回答:"
+
+ chosen = prompt + data["chosen"] + self.end_token
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen.append(
+ {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+ )
+
+ reject = prompt + data["rejected"] + self.end_token
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject.append(
+ {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
+ )
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
-
- #python representation of the object and the string representation of the object
+ return (
+ self.chosen[idx]["input_ids"],
+ self.chosen[idx]["attention_mask"],
+ self.reject[idx]["input_ids"],
+ self.reject[idx]["attention_mask"],
+ )
+
+ # python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@@ -165,26 +163,25 @@ def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
-'''
+"""
Easy SFT just accepts a text file that can be read line by line. However, the dataset will group texts together up to max_length so the LLM learns the texts' meaning better.
If individual lines are not related, just set is_group_texts to False.
-'''
+"""
class EasySFTDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
- #read the data_file line by line
+ # read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
- #encode the text data line by line and put raw python list input_ids only to raw_input_ids list
+ # encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
- #if the encoded_ids is longer than max_length, then split it into several parts
+ # if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
- raw_input_ids.append(encoded_ids[i:i + max_length])
+ raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
@@ -196,12 +193,13 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
- #pad the current_input_ids to max_length with tokenizer.pad_token_id
+ # pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
current_input_ids = []
else:
current_input_ids.extend(input_ids)
@@ -210,14 +208,16 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
else:
- #just append the raw_input_ids to max_length
+            # no grouping: pad each line's input_ids to max_length individually
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
@@ -227,14 +227,14 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
def __len__(self):
return len(self.input_ids)
- #get item from dataset
+ # get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
- #generate the dataset description to be printed by print in python
+    # dataset description returned by repr()
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
- #generate the dataset description to be printed by print in python
+    # dataset description returned by str()
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py
index fe294868159d..db629e50ed94 100644
--- a/applications/Chat/examples/community/peft/easy_models.py
+++ b/applications/Chat/examples/community/peft/easy_models.py
@@ -4,7 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
-from coati.models.utils import log_probs_from_logits, masked_mean
+from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ def __init__(self, model: nn.Module) -> None:
@torch.no_grad()
def generate(
- self,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
+ self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
- pad_token_id = kwargs.get('pad_token_id', None)
+ pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
- eos_token_id = kwargs.get('eos_token_id', None)
+ eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
- return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+ return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
- def forward(self,
- sequences: torch.LongTensor,
- num_actions: int,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- """Returns action log probs
- """
+ def forward(
+ self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
- logits = output['logits']
+ logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_path: str = None) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_path: str = None,
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
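
`Actor.forward` above scores only the generated action tokens via `log_probs_from_logits`. A hedged sketch of what that helper computes, with toy shapes (the function body illustrates the idea, not coati's exact implementation):

```python
import torch
import torch.nn.functional as F


def log_probs_of_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); labels: (batch, seq) -> per-token log-probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


logits = torch.randn(2, 6, 10)            # toy model output
sequences = torch.randint(0, 10, (2, 6))  # toy generated sequences
# Shift by one: logits at position t score the token at position t + 1,
# matching log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) above.
token_log_probs = log_probs_of_labels(logits[:, :-1, :], sequences[:, 1:])
print(token_log_probs.shape)  # torch.Size([2, 5])
```

Slicing the result with `[:, -num_actions:]` then keeps only the log-probs of the newly generated action tokens, which is what PPO needs.
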
diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py
index 9385e457d852..99a024f1463c 100644
--- a/applications/Chat/examples/community/peft/train_peft_prompts.py
+++ b/applications/Chat/examples/community/peft/train_peft_prompts.py
@@ -1,18 +1,16 @@
import argparse
-import pandas as pd
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.gpt import GPTRM, GPTCritic
+from coati.models.llama import LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
-from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -23,24 +21,24 @@
def main(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- state_dict = torch.load(args.rm_path, map_location='cpu')
+ state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
- if args.model == 'bloom':
+ if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
- print('Using peft lora to load Bloom model as initial_model')
+ print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as initial_model (Done)')
+ print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -49,59 +47,59 @@ def main(args):
else:
rm_model_name = args.rm_model
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
- if args.model == 'bloom':
+ if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- print('Using peft lora to load Bloom model as Actor')
+ print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as Actor (Done)')
+ print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -109,18 +107,18 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
- if args.model == 'gpt2':
+ if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "</s>"  # "<\s>" was an invalid escape; LLaMA's eos token is </s>
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.train_batch_size)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
+ )
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size,
- collate_fn=data_collator)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset,
+ shuffle=(pretrain_sampler is None),
+ sampler=pretrain_sampler,
+ batch_size=args.ptx_batch_size,
+ collate_fn=data_collator,
+ )
def tokenize_fn(texts):
        # MUST pad to max length so inputs on all ranks have the same length.
        # Different lengths can hang gemini, because ranks would take different numbers of generation steps.
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
@@ -178,45 +177,46 @@ def tokenize_fn(texts):
eos_token_id=tokenizer.eos_token_id,
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- num_update_steps=args.num_update_steps,
- num_collect_steps=args.num_collect_steps)
+ trainer.fit(
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--sft_lora_path', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--num_collect_steps', type=int, default=10)
- parser.add_argument('--num_update_steps', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=2)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
+ parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
+ parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--sft_lora_path", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=2)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)
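
One change above is substantive rather than cosmetic: the Gemini branch now uses the newer placement API. With `placement_policy="static"`, setting `offload_optim_frac=1.0` and `offload_param_frac=1.0` keeps optimizer states and parameters fully on CPU, reproducing the retired `placement_policy='cpu'` spelling. A hedged usage sketch (signatures taken from this diff; other Colossal-AI versions may differ):

```python
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy


def build_strategy(name: str):
    if name == "ddp":
        return DDPStrategy()
    if name == "colossalai_gemini":
        # full CPU offload, equivalent in intent to the old placement_policy="cpu"
        return GeminiStrategy(
            placement_policy="static",
            offload_optim_frac=1.0,
            offload_param_frac=1.0,
            initial_scale=2**5,
        )
    if name == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    raise ValueError(f'Unsupported strategy "{name}"')
```

Fractions between 0.0 and 1.0 offload only part of the state, trading GPU memory for PCIe traffic.
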
diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py
index 4af08e6d0141..3bbef7208374 100644
--- a/applications/Chat/examples/community/peft/train_peft_sft.py
+++ b/applications/Chat/examples/community/peft/train_peft_sft.py
@@ -1,18 +1,10 @@
import argparse
import os
-import loralib as lora
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
@@ -29,75 +21,76 @@
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
- print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
+ print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
        # if save_path already holds a saved adapter (adapter_config.json and adapter_model.bin), resume from it
- if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
- and os.path.exists(args.save_path + '/adapter_model.bin'):
+ if (
+ os.path.exists(args.save_path)
+ and os.path.exists(args.save_path + "/adapter_config.json")
+ and os.path.exists(args.save_path + "/adapter_model.bin")
+ ):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
# we'll use peft lora library to do the lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
# config lora with rank of lora_rank
- lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
- inference_mode=False,
- r=lora_rank,
- lora_alpha=32,
- lora_dropout=0.1)
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
+ )
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
- tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "</s>"  # "<\s>" was an invalid escape; LLaMA's eos token is </s>
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.model == 'llama' and args.strategy == 'colossalai_gemini':
+ if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
+ sub_module_name = ".".join(name.split(".")[:-1])
+ weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
- logger.set_level('WARNING')
+ logger.set_level("WARNING")
# configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
@@ -108,47 +101,57 @@ def train(args):
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- batch_size=args.batch_size,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
trainer.fit(logger=logger, log_interval=args.log_interval)
@@ -156,29 +159,27 @@ def train(args):
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--eval_dataset', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--enable_peft_lora', action='store_true', default=False)
- parser.add_argument("--is_short_text", action='store_true', default=False)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--eval_dataset", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--enable_peft_lora", action="store_true", default=False)
+ parser.add_argument("--is_short_text", action="store_true", default=False)
args = parser.parse_args()
train(args)
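
The PEFT wiring above boils down to: resume from a saved adapter directory if one exists, otherwise attach a fresh LoRA adapter. A minimal standalone sketch (model name and `save_path` are placeholders):

```python
import os

from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
save_path = "output"
if os.path.exists(os.path.join(save_path, "adapter_config.json")):
    model = PeftModel.from_pretrained(model, save_path)  # resume the saved adapter
else:
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, r=32, lora_alpha=32, lora_dropout=0.1
    )
    model = get_peft_model(model, lora_config)  # wrap the base model with LoRA
model.print_trainable_parameters()  # only the low-rank matrices remain trainable
```

Because only adapter weights are saved, checking for `adapter_config.json` (plus `adapter_model.bin`, as the script above does) is enough to detect a resumable run.
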
diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py
index 53f304d379fe..e8a1175a9c32 100644
--- a/applications/Chat/examples/community/ray/ray_job_script.py
+++ b/applications/Chat/examples/community/ray/ray_job_script.py
@@ -6,16 +6,25 @@
def main(api_server_endpoint="http://127.0.0.1:8265"):
client = JobSubmissionClient(api_server_endpoint)
client.submit_job(
- entrypoint=
- "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
+ entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
runtime_env={
- "working_dir":
- "applications/Chat",
+ "working_dir": "applications/Chat",
"pip": [
- "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
- "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
- ]
- })
+ "torch==1.13.1",
+ "transformers>=4.20.1",
+ "datasets",
+ "loralib",
+ "colossalai>=0.2.4",
+ "langchain",
+ "tokenizers",
+ "fastapi",
+ "sse_starlette",
+ "wandb",
+ "sentencepiece",
+ "gpustat",
+ ],
+ },
+ )
if __name__ == "__main__":
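
`submit_job` returns as soon as the job is queued; the job itself then runs asynchronously against the Ray cluster. A hedged companion sketch for waiting on the result, using the standard `ray.job_submission` APIs (the entrypoint here is a toy placeholder):

```python
import time

from ray.job_submission import JobStatus, JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(entrypoint="python -c 'print(42)'")
while True:
    status = client.get_job_status(job_id)
    if status in {JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.STOPPED}:
        break
    time.sleep(5)  # poll until the job reaches a terminal state
print(status)
print(client.get_job_logs(job_id))
```
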
diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
index 1bba9ad66fbc..8abd83a8b249 100644
--- a/applications/Chat/examples/community/ray/train_prompts_on_ray.py
+++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
@@ -26,9 +26,14 @@
class ExperienceCompositionRefs:
-
- def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
- base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
+ def __init__(
+ self,
+ sequences_attention_mask_action_mask_ref: ray.ObjectRef,
+ action_log_probs_ref: ray.ObjectRef,
+ base_action_log_probs_ref: ray.ObjectRef,
+ value_ref: ray.ObjectRef,
+ r_ref: ray.ObjectRef,
+ ) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, acti
class ExperienceMaker:
-
def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get(
- experiment_computation_refs.sequences_attention_mask_action_mask_ref)
+ experiment_computation_refs.sequences_attention_mask_action_mask_ref
+ )
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref)
@@ -58,11 +63,10 @@ def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs
class DistributedTorchRayActor:
-
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
self._model = None
self._world_size = world_size
self._rank = rank
@@ -82,7 +86,7 @@ def _get_current_node_ip():
@staticmethod
def _get_free_port():
with socket.socket() as sock:
- sock.bind(('', 0))
+ sock.bind(("", 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
@@ -90,7 +94,6 @@ def get_master_addr_port(self):
class BasePPORole(DistributedTorchRayActor):
-
def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef)
@@ -99,12 +102,12 @@ def make_experience(self, experience_computation_ref: ExperienceCompositionRefs)
def _init_strategy(self, strategy: str):
# configure strategy
- if strategy == 'ddp':
+ if strategy == "ddp":
self._strategy = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif strategy == "colossalai_gemini":
+ self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -124,11 +127,9 @@ def _prepare_model_with_strategy(self, has_optimizer: bool):
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError()
- def init_model_from_pretrained(self,
- strategy: str,
- model_class: Type[LoRAModule],
- pretrain: str,
- has_optimizer=False):
+ def init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
+ ):
self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer)
@@ -138,7 +139,6 @@ def eval(self):
class TrainablePPORole(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -161,38 +161,39 @@ def learn_on_experiences(self, experience_refs):
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
-
def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
- if model_type == 'gpt2':
+ if model_type == "gpt2":
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'bloom':
+ elif model_type == "bloom":
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'opt':
+ elif model_type == "opt":
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else:
raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts):
- batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+ batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs
+
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
- self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
- self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
+ self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
+ self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd
- prompts = pd.read_csv(prompt_url)['prompt']
+
+ prompts = pd.read_csv(prompt_url)["prompt"]
self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs):
@@ -214,10 +215,9 @@ def calculate_action_log_probs(self, sequence_attention_action_mask):
def _training_step(self, experience):
num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self._actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self._actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -229,17 +229,18 @@ def save_checkpoint(self, save_path, should_save_optimizer: bool):
self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if should_save_optimizer:
- self._strategy.save_optimizer(self._optimizer,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ self._strategy.save_optimizer(
+ self._optimizer,
+ "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
+ only_rank0=False,
+ )
def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
- encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
+ encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.cuda() for k, v in encoded_input.items()}
- sequence, _ = self._model.generate(**input_ids,
- max_length=max_length,
- return_action_mask=False,
- num_return_sequences=num_return_sequences)
+ sequence, _ = self._model.generate(
+ **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
+ )
token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output
@@ -247,18 +248,16 @@ def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
-
def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience):
- values = self._model(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self._critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self._model(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self._critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -272,12 +271,12 @@ def calculate_value(self, sequence_attention_action_mask):
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
- self._model = RewardModel(deepcopy(critic.model),
- deepcopy(critic.value_head)).to(torch.cuda.current_device())
+ self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
+ torch.cuda.current_device()
+ )
@torch.no_grad()
def calculate_r(self, sequence_attention_action_mask):
@@ -287,7 +286,6 @@ def calculate_r(self, sequence_attention_action_mask):
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -300,8 +298,8 @@ def calculate_base_action_log_probs(self, sequence_attention_action_mask):
class PPORayActorGroup:
"""
- A group of ray actors
- Functions start with 'async' should return list of object refs
+        A group of Ray actors.
+        Functions whose names start with 'async' return lists of object refs.
"""
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -319,8 +317,9 @@ def _initiate_actors(self):
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
if pg:
- master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
+ master_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
+ ).remote(world_size, 0, 0, None, None)
else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor]
@@ -331,16 +330,20 @@ def _initiate_actors(self):
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
if pg:
- worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
- world_size, rank, local_rank, master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
+ placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
+ )
+ ).remote(world_size, rank, local_rank, master_addr, master_port)
else:
- worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
- master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
+ world_size, rank, local_rank, master_addr, master_port
+ )
self._actor_handlers.append(worker_actor)
- def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
- has_optimizer: bool):
+ def async_init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
+ ):
return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers
@@ -348,7 +351,6 @@ def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRA
class TrainableModelRayActorGroup(PPORayActorGroup):
-
def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers)
learn_result_refs = []
@@ -359,7 +361,6 @@ def async_learn_on_experiences(self, experience_refs):
class PPOActorRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
@@ -381,7 +382,8 @@ def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_
action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs
@@ -393,7 +395,6 @@ def save_checkpoint(self, save_path, should_save_optimizer):
class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
@@ -402,7 +403,8 @@ def async_calculate_value(self, sequences_attention_mask_action_mask_refs):
value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
value_refs.append(value_ref)
return value_refs
@@ -411,7 +413,6 @@ def set_loss_function(self, value_clip: float = 0.4):
class PPOInitialRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
@@ -420,13 +421,13 @@ def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_
base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
@@ -435,20 +436,21 @@ def async_calculate_r(self, sequences_attention_mask_action_mask_refs):
r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
r_refs.append(r_ref)
return r_refs
def main(args):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
- if args.model == 'gpt2':
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
+ if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
- elif args.model == 'opt':
+ elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -462,13 +464,14 @@ def main(args):
logging.info("Actors created")
# Prepare model for training
- generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
+ generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get(
- actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
- critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
- initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
- reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
- actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
+ actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+ + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+ + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+ + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+ + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
+ )
logging.info("Models prepared for training")
# Prepare models for training
@@ -483,8 +486,12 @@ def main(args):
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
- all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
- initial_group._actor_handlers + reward_group._actor_handlers
+ all_ray_actors = (
+ actor_group._actor_handlers
+ + critic_group._actor_handlers
+ + initial_group._actor_handlers
+ + reward_group._actor_handlers
+ )
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -497,18 +504,28 @@ def main(args):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
- experience_batch_size)
+ experience_batch_size
+ )
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
- sequences_attention_mask_action_mask_refs)
+ sequences_attention_mask_action_mask_refs
+ )
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
- sequences_attention_mask_action_mask_refs)
- experience_composition_refs.extend([
- ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
- base_action_log_probs_refs[i], values_refs[i], r_refs[i])
- for i in range(len(sequences_attention_mask_action_mask_refs))
- ])
+ sequences_attention_mask_action_mask_refs
+ )
+ experience_composition_refs.extend(
+ [
+ ExperienceCompositionRefs(
+ sequences_attention_mask_action_mask_refs[i],
+ action_log_probs_refs[i],
+ base_action_log_probs_refs[i],
+ values_refs[i],
+ r_refs[i],
+ )
+ for i in range(len(sequences_attention_mask_action_mask_refs))
+ ]
+ )
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
@@ -519,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
- actor_group.async_learn_on_experiences(experience_refs) +
- critic_group.async_learn_on_experiences(experience_refs))
+ actor_group.async_learn_on_experiences(experience_refs)
+ + critic_group.async_learn_on_experiences(experience_refs)
+ )
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
@@ -528,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_csv_url', type=str)
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default='gpt2')
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
- parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
- parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
- parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
- parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
+ parser.add_argument("--prompt_csv_url", type=str)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
+ parser.add_argument("--pretrain", type=str, default="gpt2")
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--max_timesteps", type=int, default=10)
+ parser.add_argument("--update_timesteps", type=int, default=10)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
+ parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
+ parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
+ parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
+ parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
args = parser.parse_args()
ray.init()
main(args)
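
The scheduling pattern in `PPORayActorGroup._initiate_actors` above is worth isolating: reserve one GPU bundle per node with a `STRICT_SPREAD` placement group, then pin each rank's actor to its bundle. A hedged sketch (requires a cluster with two GPU nodes, otherwise `pg.ready()` blocks; the `Worker` class is a toy stand-in):

```python
import socket

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_gpus=1)
class Worker:
    def ping(self):
        return socket.gethostname()  # show which node each actor landed on


ray.init()
pg = placement_group([{"GPU": 1}] * 2, strategy="STRICT_SPREAD")
ray.get(pg.ready())  # blocks until both bundles are reserved on distinct nodes
workers = [
    Worker.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg, placement_group_bundle_index=i
        )
    ).remote()
    for i in range(2)
]
print(ray.get([w.ping.remote() for w in workers]))  # two different hostnames
```

Rank 0 additionally publishes its address and a free port so later ranks can join the same torch.distributed process group, as the `DistributedTorchRayActor` base class does.
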
diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py
index c2b5f9a859a9..ec3482b5f789 100644
--- a/applications/Chat/examples/download_model.py
+++ b/applications/Chat/examples/download_model.py
@@ -22,7 +22,7 @@ def download(self, dir_path: str):
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
def download_all(self):
- file_path = snapshot_download(self.repo_id)
+ snapshot_download(self.repo_id)
def test_init(model: str, dir_path: str):
@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str):
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
- tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
+ GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
- tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
+ BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
- tokenizer = AutoTokenizer.from_pretrained(dir_path)
+ AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")
@@ -59,17 +59,12 @@ def test_init(model: str, dir_path: str):
exit(0)
repo_list = {
- "gpt2": HFRepoFiles(
- repo_id="gpt2",
- files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
- ),
+ "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
"bloom": HFRepoFiles(
- repo_id="bigscience/bloom-560m",
- files=["config.json", "tokenizer.json", "tokenizer_config.json"]
+ repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
),
"opt": HFRepoFiles(
- repo_id="facebook/opt-350m",
- files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
+ repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
),
}
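
The downloader distinguishes two `huggingface_hub` modes: fetching only the few files needed to build a tokenizer/config shell, versus mirroring a whole repository. A minimal sketch (repo and file names follow the `repo_list` entries above; `local_dir` is a placeholder):

```python
from huggingface_hub import hf_hub_download, snapshot_download

# per-file download: just the config and tokenizer artifacts
for file in ["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]:
    hf_hub_download("facebook/opt-350m", file, local_dir="opt-350m")

# full-repo snapshot into the local Hugging Face cache
snapshot_download("facebook/opt-350m")
```
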
diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py
index 8d2fbba955b8..7e03b2d54260 100644
--- a/applications/Chat/examples/generate_conversation_dataset.py
+++ b/applications/Chat/examples/generate_conversation_dataset.py
@@ -31,9 +31,11 @@ def generate_alpaca():
def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
- dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
- data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
- split="train")
+ dataset = load_dataset(
+ "anon8231489123/ShareGPT_Vicuna_unfiltered",
+ data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
+ split="train",
+ )
conversations = dataset["conversations"]
@@ -43,23 +45,24 @@ def generate_sharegpt():
del conv["markdown"]
del conv["text"]
- conversation = dict(type="conversation",
- language="Multilingual",
- dataset="ShareGPT",
- conversations=conversations[idx])
+ conversation = dict(
+ type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
+ )
conversation_dataset.append(conversation)
return conversation_dataset
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--dataset',
- type=str,
- default="All",
- choices=["Alpaca", "ShareGPT", "All"],
- help="which dataset to convert, All will combine Alpaca and ShareGPT")
- parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="All",
+ choices=["Alpaca", "ShareGPT", "All"],
+ help="which dataset to convert, All will combine Alpaca and ShareGPT",
+ )
+ parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()
conversation_dataset = []
@@ -75,5 +78,5 @@ def generate_sharegpt():
for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1
- with open(args.save_path, mode='w') as f:
+ with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
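
The converter above normalizes every sample to a single record shape before dumping, then assigns 1-based ids. A tiny sketch of that shape; the `from`/`value` keys inside `conversations` follow the usual ShareGPT convention and are an assumption here, not taken from the diff:

```python
# Sketch of the record shape produced above; ensure_ascii=False keeps
# multilingual text readable, and ids are assigned 1-based at the end.
import json

conversation_dataset = [
    dict(
        type="conversation",
        language="Multilingual",
        dataset="ShareGPT",
        conversations=[
            {"from": "human", "value": "Hi"},        # assumed ShareGPT keys
            {"from": "gpt", "value": "Hello there"},
        ],
    )
]
for idx, sample in enumerate(conversation_dataset):
    sample["id"] = idx + 1

with open("dataset.json", mode="w") as f:
    json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
```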
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
index 2abb31c09f82..4eec6feae505 100644
--- a/applications/Chat/examples/generate_prompt_dataset.py
+++ b/applications/Chat/examples/generate_prompt_dataset.py
@@ -6,7 +6,7 @@
def sample(args):
- with open(args.dataset_path, mode='r') as f:
+ with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f)
sampled_dataset = [
@@ -14,18 +14,14 @@ def sample(args):
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
]
- with open(args.save_path, mode='w') as f:
- json.dump(sampled_dataset, f, indent=4,
- default=str, ensure_ascii=False)
+ with open(args.save_path, mode="w") as f:
+ json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--dataset_path', type=str, default=None,
- required=True, help="path to the pretrain dataset")
- parser.add_argument('--save_path', type=str, default='prompt.json',
- help="path to save the prompt dataset")
- parser.add_argument('--sample_size', type=int,
- default=16384, help="size of the prompt dataset")
+ parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
+ parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
+ parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
index e1e57e3cd376..62e06bf7b3bb 100644
--- a/applications/Chat/examples/inference.py
+++ b/applications/Chat/examples/inference.py
@@ -11,13 +11,13 @@
def eval(args):
# configure model
- if args.model == 'gpt2':
+ if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain)
- elif args.model == 'opt':
+ elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain)
- elif args.model == 'llama':
+ elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -28,45 +28,46 @@ def eval(args):
actor.load_state_dict(state_dict)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
- input_ids = tokenizer.encode(args.input,
- return_tensors='pt')\
- .to(torch.cuda.current_device())
- outputs = generate(actor,
- input_ids,
- max_length=args.max_length,
- do_sample=True,
- top_k=50,
- top_p=0.95,
- num_return_sequences=1)
- output = tokenizer.batch_decode(outputs[0],
- skip_special_tokens=True)
+ tokenizer.padding_side = "left"
+ input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
+ outputs = generate(
+ actor,
+ input_ids,
+ tokenizer=tokenizer,
+ max_length=args.max_length,
+ do_sample=True,
+ top_k=50,
+ top_p=0.95,
+ num_return_sequences=1,
+ )
+ output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}")
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    # We suggest using a pretrained model from HuggingFace; use --pretrain to configure the model
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
- parser.add_argument('--max_length', type=int, default=100)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
+ parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args()
eval(args)
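
The newly added `tokenizer.padding_side = "left"` matters because decoder-only models generate from the end of the sequence: with right padding, pad tokens would sit between the prompt and the generated continuation. A small demonstration, assuming only `transformers` is installed:

```python
# Demo (assumes transformers): left padding keeps real tokens at the end
# of each row, which is what decoder-only generation expects.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

batch = tokenizer(
    ["Question: How are you? Answer:", "Hi"],
    return_tensors="pt",
    padding=True,
)
# every row now ends with real text; pad ids sit on the left
print(batch["input_ids"][:, -1])
```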
diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py
index 5dd52f1790e6..8de6219ec4e9 100644
--- a/applications/Chat/examples/ray/1mmt_prompt.py
+++ b/applications/Chat/examples/ray/1mmt_prompt.py
@@ -5,7 +5,6 @@
import pandas as pd
import ray
-import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -37,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
- 'local_rank': '0',
- 'rank': '0',
- 'world_size': '1',
- 'master_port': maker_port,
- 'master_addr': master_addr
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
}
# configure tokenizer
@@ -75,27 +77,33 @@ def trainer_model_fn():
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
- ) for i, env_info_trainer in enumerate(env_info_trainers)
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
]
def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -130,12 +138,11 @@ def model_fn():
dataset_size = args.experience_batch_size * 4
def build_dataloader():
-
def tokenize_fn(texts):
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
- dataset = pd.read_csv(args.prompt_path)['prompt']
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -144,32 +151,31 @@ def tokenize_fn(texts):
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None)
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="rank of the low-rank adaptation matrices")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
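
The `get_free_port`/`get_local_ip` helpers reformatted above rely on two standard-library tricks: binding a TCP socket to port 0 asks the OS for any free port, and connecting a UDP socket sends no packets but forces the OS to pick the outbound interface, whose address `getsockname()` then reports. A standalone sketch:

```python
# Standalone sketch of the address-discovery tricks used by the ray launchers.
import socket

def get_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))             # port 0 = "give me any free port"
        return s.getsockname()[1]

def get_local_ip() -> str:
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))  # UDP connect sends no packet
        return s.getsockname()[0]

print(get_local_ip(), get_free_port())
```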
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py
index 76929c9d0144..7c03a0468b02 100644
--- a/applications/Chat/examples/ray/mmmt_prompt.py
+++ b/applications/Chat/examples/ray/mmmt_prompt.py
@@ -5,7 +5,6 @@
import pandas as pd
import ray
-import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -37,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
- env_info_makers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_makers),
- 'master_port': maker_port,
- 'master_addr': master_addr
- } for rank in range(args.num_makers)]
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,13 +68,18 @@ def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -78,7 +88,7 @@ def model_fn():
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
- f'trainer{x}'
+ f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -87,8 +97,8 @@ def model_fn():
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
- # sync_models_from_trainers=True,
- # generation kwargs:
+ # sync_models_from_trainers=True,
+ # generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -128,12 +138,11 @@ def trainer_model_fn():
dataset_size = args.experience_batch_size * 4
def build_dataloader():
-
def tokenize_fn(texts):
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
- dataset = pd.read_csv(args.prompt_path)['prompt']
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -148,39 +157,44 @@ def tokenize_fn(texts):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
- total_steps = args.experience_batch_size * args.experience_steps * \
- args.num_makers // (args.num_trainers * args.train_batch_size)
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None)
- parser.add_argument('--num_makers', type=int, default=1)
- parser.add_argument('--num_trainers', type=int, default=1)
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
- '--trainer_strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="rank of the low-rank adaptation matrices")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
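
The reformatted `total_steps` expression balances experience production against consumption: each maker produces `experience_batch_size * experience_steps` samples, and each trainer consumes `train_batch_size` samples per update step. A worked example with the script's default values:

```python
# Worked example of the total_steps formula, using the script defaults
# (experience_batch_size=8, experience_steps=4, train_batch_size=8)
# with one maker and one trainer.
experience_batch_size, experience_steps = 8, 4
num_makers, num_trainers, train_batch_size = 1, 1, 8

total_steps = (
    experience_batch_size
    * experience_steps
    * num_makers
    // (num_trainers * train_batch_size)
)
print(total_steps)  # 8 * 4 * 1 // (1 * 8) == 4
```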
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 5d0f9f927d17..5474dfa16b3e 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
-colossalai==0.3.1
\ No newline at end of file
+colossalai==0.3.3
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index d27a70a3fef6..8868e278d85e 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -20,28 +20,32 @@
def main(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- warnings.warn('LoRA weights should be merged with the model weights')
- state_dict = torch.load(args.rm_path, map_location='cpu')
+ warnings.warn("LoRA weights should be merged with the model weights")
+ state_dict = torch.load(args.rm_path, map_location="cpu")
+
+ if args.lora_rank > 0:
+        warnings.warn("LoRA is not supported yet.")
+ args.lora_rank = 0
with strategy.model_init_context():
# configure model
- if args.model == 'gpt2':
+ if args.model == "gpt2":
initial_model = GPTActor(pretrained=args.pretrain)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
initial_model = BLOOMActor(pretrained=args.pretrain)
- elif args.model == 'opt':
+ elif args.model == "opt":
initial_model = OPTActor(pretrained=args.pretrain)
- elif args.model == 'llama':
+ elif args.model == "llama":
initial_model = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -51,13 +55,13 @@ def main(args):
else:
rm_model_name = args.rm_model
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -65,28 +69,28 @@ def main(args):
if args.rm_path is not None:
reward_model.load_state_dict(state_dict, strict=False)
- initial_model.to(torch.float16).to(torch.cuda.current_device())
- reward_model.to(torch.float16).to(torch.cuda.current_device())
+ initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
+ reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
- if args.model == 'gpt2':
+ if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'opt':
+ elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'llama':
+ elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
- critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
- critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'opt':
- critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
- critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+ if rm_model_name == "gpt2":
+ critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "bloom":
+ critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "opt":
+ critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "llama":
+ critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -94,65 +98,72 @@ def main(args):
critic.load_state_dict(state_dict, strict=False)
del state_dict
- if args.strategy != 'colossalai_gemini':
- critic.to(torch.float16).to(torch.cuda.current_device())
- actor.to(torch.float16).to(torch.cuda.current_device())
+ actor.to(torch.bfloat16).to(torch.cuda.current_device())
+ critic.to(torch.bfloat16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
- actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
- critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
+ if args.strategy.startswith("colossalai"):
+ actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
+ critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
else:
- actor_optim = Adam(actor.parameters(), lr=1e-7)
- critic_optim = Adam(critic.parameters(), lr=1e-7)
+ actor_optim = Adam(actor.parameters(), lr=args.lr)
+ critic_optim = Adam(critic.parameters(), lr=args.lr)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
-
- prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
+ # NOTE: generate() requires padding_side to be "left"
+ tokenizer.padding_side = "left"
+
+ prompt_dataset = PromptDataset(
+ tokenizer=tokenizer,
+ data_path=args.prompt_dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_input_len,
+ )
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.experience_batch_size)
-
- pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.pretrain_dataset,
- max_datasets_size=16384,
- max_length=args.max_input_len)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
+ )
+
+ pretrain_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=args.pretrain_dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_input_len,
+ )
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
+ )
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
- (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
- strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
+ )
# configure trainer
trainer = PPOTrainer(
@@ -163,6 +174,7 @@ def main(args):
initial_model,
actor_optim,
critic_optim,
+ tokenizer=tokenizer,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
@@ -171,52 +183,67 @@ def main(args):
do_sample=True,
temperature=1.0,
top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- offload_inference_models=args.strategy != 'colossalai_gemini'
+ offload_inference_models=args.strategy != "colossalai_gemini",
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- num_collect_steps=args.num_collect_steps,
- num_update_steps=args.num_update_steps)
+ trainer.fit(
+ num_episodes=args.num_episodes,
+ num_collect_steps=args.num_collect_steps,
+ num_update_steps=args.num_update_steps,
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ actor.eval()
# save model checkpoint after fitting
- strategy.save_model(actor, args.save_path, only_rank0=True)
+ strategy.save_pretrained(actor, path=args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--num_collect_steps', type=int, default=10)
- parser.add_argument('--num_update_steps', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
- parser.add_argument('--max_input_len', type=int, default=96)
- parser.add_argument('--max_seq_len', type=int, default=128)
+ parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
+    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretraining dataset")
+ parser.add_argument("--max_datasets_size", type=int, default=50000)
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
+ default="colossalai_zero2",
+ help="strategy to use",
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="rank of the low-rank adaptation matrices")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=1e-7)
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
+ parser.add_argument("--max_input_len", type=int, default=96)
+ parser.add_argument("--max_seq_len", type=int, default=128)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args()
main(args)
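
The new LoRA block above merges adapter weights by flipping `LORA_MANAGER.merge_weights` and calling `actor.eval()`. A hypothetical sketch of how such a merge-on-eval linear layer typically works; coati's implementation is assumed to follow this pattern, not copied from it:

```python
# Hypothetical merge-on-eval LoRA linear (illustrative only; coati's
# LoRA code is assumed to follow this pattern, not copied from it).
import torch
import torch.nn as nn

class LoRALinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__(in_features, out_features)
        self.weight.requires_grad_(False)           # frozen base weight
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.randn(out_features, r) * 0.01)
        self.merged = False

    def train(self, mode: bool = True):
        super().train(mode)
        if not mode and not self.merged:
            # eval(): fold the low-rank update B @ A into the dense weight
            self.weight.data += self.lora_B @ self.lora_A
            self.merged = True
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if not self.merged:                         # adapter path while training
            out = out + x @ self.lora_A.T @ self.lora_B.T
        return out

layer = LoRALinear(16, 16)
layer.eval()  # merging happens here, as in the training scripts above
```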
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 190460bc20f6..df6e8b6bdc26 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -1,5 +1,5 @@
import argparse
-from random import randint
+import warnings
import torch
import torch.distributed as dist
@@ -24,65 +24,69 @@
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="auto")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
+ if args.lora_rank > 0:
+        warnings.warn("LoRA is not supported yet.")
+ args.lora_rank = 0
+
with strategy.model_init_context():
- if args.model == 'bloom':
+ if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'opt':
+ elif args.model == "opt":
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'gpt2':
+ elif args.model == "gpt2":
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'llama':
+ elif args.model == "llama":
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- model.to(torch.float16).to(torch.cuda.current_device())
+ model.to(torch.bfloat16).to(torch.cuda.current_device())
if args.model_path is not None:
state_dict = torch.load(args.model_path)
model.load_state_dict(state_dict)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure optimizer
- if args.strategy.startswith('colossalai'):
- optim = HybridAdam(model.parameters(), lr=5e-6)
+ if args.strategy.startswith("colossalai"):
+ optim = HybridAdam(model.parameters(), lr=args.lr)
else:
- optim = Adam(model.parameters(), lr=5e-6)
+ optim = Adam(model.parameters(), lr=args.lr)
# configure loss function
- if args.loss_fn == 'log_sig':
+ if args.loss_fn == "log_sig":
loss_fn = LogSigLoss()
- elif args.loss_fn == 'log_exp':
+ elif args.loss_fn == "log_exp":
loss_fn = LogExpLoss()
else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
@@ -93,110 +97,112 @@ def train(args):
else:
data = load_dataset(args.dataset)
- if args.test:
- train_data = data['train'].select(range(20))
- eval_data = data['test'].select(range(5))
- else:
- train_data = data['train']
- eval_data = data['test']
- valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
+ train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
+ eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
- if args.dataset == 'Dahoas/rm-static':
+ if args.dataset == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
- valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
- elif args.dataset == 'Anthropic/hh-rlhf':
+ elif args.dataset == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
- valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
else:
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- valid_sampler = DistributedSampler(valid_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
- valid_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- valid_dataloader = DataLoader(valid_dataset,
- shuffle=(valid_sampler is None),
- sampler=valid_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
+ )
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict['model']
- optim = strategy_dict['optimizer']
- lr_scheduler = strategy_dict['lr_scheduler']
- trainer = RewardModelTrainer(model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- loss_fn=loss_fn,
- max_epochs=args.max_epochs)
-
- trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = RewardModelTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ loss_fn=loss_fn,
+ max_epochs=args.max_epochs,
+ )
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
# save model checkpoint after fitting on only rank0
- strategy.save_model(model, args.save_path, only_rank0=True)
+ state_dict = model.state_dict()
+ torch.save(state_dict, args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--dataset',
- type=str,
- choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
- default='Dahoas/rm-static')
- parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
- parser.add_argument('--save_path', type=str, default='rm_ckpt')
- parser.add_argument('--max_epochs', type=int, default=1)
- parser.add_argument('--batch_size', type=int, default=1)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
- parser.add_argument('--test', type=bool, default=False)
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument(
+ "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
+ )
+ parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
+ parser.add_argument("--max_datasets_size", type=int, default=1000000)
+ parser.add_argument("--save_path", type=str, default="rm_ckpt")
+ parser.add_argument("--max_epochs", type=int, default=1)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--max_len", type=int, default=512)
+    parser.add_argument("--lora_rank", type=int, default=0, help="rank of the low-rank adaptation matrices")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=9e-6)
+ parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args()
train(args)
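
`--loss_fn` selects between two pairwise ranking objectives for the reward model. A sketch of the standard formulations, which coati's `LogSigLoss` and `LogExpLoss` are assumed to match up to reduction details; the two are in fact the same function written two ways:

```python
# Pairwise ranking losses selected via --loss_fn; mathematically
# -log(sigmoid(x)) == log(1 + exp(-x)), so both agree.
import torch
import torch.nn.functional as F

def log_sig_loss(chosen: torch.Tensor, rejected: torch.Tensor) -> torch.Tensor:
    # "log_sig": -log sigmoid(r_chosen - r_rejected)
    return -F.logsigmoid(chosen - rejected).mean()

def log_exp_loss(chosen: torch.Tensor, rejected: torch.Tensor) -> torch.Tensor:
    # "log_exp": log(1 + exp(r_rejected - r_chosen))
    return torch.log(1 + torch.exp(rejected - chosen)).mean()

chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.4, 0.9])
assert torch.allclose(log_sig_loss(chosen, rejected), log_exp_loss(chosen, rejected))
```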
diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh
index cc1b7be2815f..c5ebaf708ddc 100755
--- a/applications/Chat/examples/train_rm.sh
+++ b/applications/Chat/examples/train_rm.sh
@@ -16,7 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
- --model 'bloom' \
+ --pretrain 'gpt2' \
+ --model 'gpt2' \
--strategy colossalai_zero2 \
- --loss_fn 'log_sig' \
- --dataset 'Anthropic/hh-rlhf'
+ --loss_fn 'log_exp' \
+ --dataset 'Anthropic/hh-rlhf' \
+ --batch_size 16 \
+ --max_epochs 10
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index f068ea2bf5de..66d08da30120 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -6,206 +6,216 @@
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
-from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="auto")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
if args.lora_rank > 0:
- warnings.warn("Gradient checkpoint is disabled when using LoRA")
- args.grad_checkpoint = False
+        warnings.warn("LoRA is not supported yet.")
+ args.lora_rank = 0
+
with strategy.model_init_context():
- if args.model == 'bloom':
- model = BLOOMActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'opt':
- model = OPTActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'gpt2':
- model = GPTActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'llama':
- model = LlamaActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'chatglm':
+ if args.model == "bloom":
+ model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "opt":
+ model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "gpt2":
+ model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "llama":
+ model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "chatglm":
model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- model.to(torch.float16).to(torch.cuda.current_device())
+ model.to(torch.bfloat16).to(torch.cuda.current_device())
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
- elif args.model == 'chatglm':
+ elif args.model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained(
- "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
+ )
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.model == 'llama' and args.strategy == 'colossalai_gemini':
- # this is a hack to deal with the resized embedding
- # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
- for name, param in model.named_parameters():
- if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
- sub_module = model.get_submodule(sub_module_name)
- setattr(sub_module, weight_name, ColoParameter(param))
-
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
- logger = get_dist_logger()
# configure dataset
- if args.dataset == 'yizhongw/self_instruct':
- train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
- eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+ if args.dataset == "yizhongw/self_instruct":
+ train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
+ eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
+
+ if args.max_datasets_size is not None:
+ train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
+ eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
else:
- train_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=args.max_len)
+ train_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=args.dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_len,
+ )
eval_dataset = None
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
- lr_scheduler = get_scheduler("cosine",
- optim,
- num_warmup_steps=math.ceil(max_steps * 0.03),
- num_training_steps=max_steps)
+ lr_scheduler = get_scheduler(
+ "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
+ )
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict['model']
- optim = strategy_dict['optimizer']
- lr_scheduler = strategy_dict['lr_scheduler']
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
-
- trainer.fit(train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- logger=logger,
- use_wandb=args.use_wandb)
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
+ logger = get_dist_logger()
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ logger=logger,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
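+        # (assumption) switching to eval() is what makes each LoraLinear fold B @ A back into its frozen base weight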
+ model.eval()
    # save the model checkpoint after fitting, on rank0 only
- strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+ strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+        trainer.optimizer, "sft_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--max_datasets_size', type=int, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--use_wandb', default=False, action='store_true')
- parser.add_argument('--grad_checkpoint', default=False, action='store_true')
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
+ default="colossalai_zero2",
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--max_datasets_size", type=int, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args()
train(args)
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index 1a5cd069011d..0fb4da3d3ce8 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
- --log_interval 10 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 4 \
diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py
index 438a1e3ef1c7..dbb5490a63dc 100644
--- a/applications/Chat/inference/benchmark.py
+++ b/applications/Chat/inference/benchmark.py
@@ -84,28 +84,34 @@ def evaluate(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
+ parser.add_argument(
+ "--quant",
+ choices=["8bit", "4bit"],
+ default=None,
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--gptq_checkpoint",
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
- if args.quant == '4bit':
+ if args.quant == "4bit":
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
@@ -114,12 +120,12 @@ def evaluate(
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
total_tokens = 0
@@ -129,7 +135,7 @@ def evaluate(
resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
total_tokens += tokens
print(f"Response: {resp}")
- print('\n----------------------------\n')
+ print("\n----------------------------\n")
duration = time() - start
- print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s')
- print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
+ print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
+ print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py
index 9443d4b99180..333262e538ac 100644
--- a/applications/Chat/inference/locustfile.py
+++ b/applications/Chat/inference/locustfile.py
@@ -1,26 +1,26 @@
-from json import JSONDecodeError
-
from locust import HttpUser, task
-samples = [[
- dict(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- dict(instruction='continue this talk', response=''),
-], [
- dict(instruction='Who is the best player in the history of NBA?', response=''),
-]]
+samples = [
+ [
+ dict(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ dict(instruction="continue this talk", response=""),
+ ],
+ [
+ dict(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+]
class GenerationUser(HttpUser):
-
@task
def generate(self):
for sample in samples:
- data = {'max_new_tokens': 64, 'history': sample}
- with self.client.post('/generate', json=data, catch_response=True) as response:
+ data = {"max_new_tokens": 64, "history": sample}
+ with self.client.post("/generate", json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
- response.failure('Response wrong')
+ response.failure("Response wrong")
diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index 9d6b7fabef54..7c6a61b9e7f2 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -16,7 +16,7 @@
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512
running_lock = Lock()
@@ -36,11 +36,11 @@ class GenerationTaskReq(BaseModel):
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# set CORS
-origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
+origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
if origin_spec_from_env is not None:
# allow CORS from the specified origins
- origins = os.environ['CORS_ORIGIN'].split(',')
+ origins = os.environ["CORS_ORIGIN"].split(",")
else:
# allow CORS from all origins
origins = ["*"]
@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty yet
model_kwargs = {
- 'max_generate_tokens': max_new_tokens,
- 'early_stopping': True,
- 'top_k': top_k,
- 'top_p': top_p,
- 'temperature': temperature,
- 'prepare_inputs_fn': model.prepare_inputs_for_generation,
- 'update_model_kwargs_fn': update_model_kwargs_fn,
+ "max_generate_tokens": max_new_tokens,
+ "early_stopping": True,
+ "top_k": top_k,
+ "top_p": top_p,
+ "temperature": temperature,
+ "prepare_inputs_fn": model.prepare_inputs_for_generation,
+ "update_model_kwargs_fn": update_model_kwargs_fn,
}
is_first_word = True
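+    # the lock serializes next-token generation, so concurrent streaming requests never interleave on the shared model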
generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
if is_first_word:
out_string = out_string.lstrip()
is_first_word = False
- elif current_sub_tokens[0].startswith('▁'):
+ elif current_sub_tokens[0].startswith("▁"):
# whitespace will be ignored by the frontend
- out_string = ' ' + out_string
+ out_string = " " + out_string
yield out_string
@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
if await request.is_disconnected():
break
try:
- yield {'event': 'generate', 'data': next(generator)}
+ yield {"event": "generate", "data": next(generator)}
except StopIteration:
- yield {'event': 'end', 'data': ''}
+ yield {"event": "end", "data": ""}
break
-@app.post('/generate/stream')
-@limiter.limit('1/second')
+@app.post("/generate/stream")
+@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
event_source = event_generator(
- request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
+ request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+ )
return EventSourceResponse(event_source)
-@app.post('/generate')
-@limiter.limit('1/second')
+@app.post("/generate")
+@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock:
- output = model.generate(**inputs, **data.dict(exclude={'history'}))
+ output = model.generate(**inputs, **data.dict(exclude={"history"}))
output = output.cpu()
- prompt_len = inputs['input_ids'].size(1)
+ prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
out_string = prompt_processor.postprocess_output(out_string)
@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
return out_string
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--quant",
+ choices=["8bit", "4bit"],
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
- parser.add_argument('--http_host', default='0.0.0.0')
- parser.add_argument('--http_port', type=int, default=7070)
- parser.add_argument('--profanity_file',
- default=None,
- help='Path to profanity words list. It should be a JSON file containing a list of words.')
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
+ parser.add_argument(
+ "--gptq_checkpoint",
+ default=None,
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
+ parser.add_argument("--http_host", default="0.0.0.0")
+ parser.add_argument("--http_port", type=int, default=7070)
+ parser.add_argument(
+ "--profanity_file",
+ default=None,
+ help="Path to profanity words list. It should be a JSON file containing a list of words.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
@@ -161,7 +170,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
- if args.quant == '4bit':
+ if args.quant == "4bit":
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
@@ -170,12 +179,12 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py
index 23028d4959cb..9835e71894c6 100644
--- a/applications/Chat/inference/tests/test_chat_prompt.py
+++ b/applications/Chat/inference/tests/test_chat_prompt.py
@@ -3,41 +3,49 @@
from transformers import AutoTokenizer
from utils import ChatPromptProcessor, Dialogue
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH'])
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
+tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
samples = [
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 200,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 200,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 211,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 211,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
),
- ([
- Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
+ (
+ [
+ Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
),
]
@@ -49,5 +57,5 @@ def test_chat_prompt_processor():
assert prompt == result
-if __name__ == '__main__':
+if __name__ == "__main__":
test_chat_prompt_processor()
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index e8e7b05ac719..af018adf6e9d 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -20,9 +20,9 @@
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def sample_streamingly(model: nn.Module,
- input_ids: torch.Tensor,
- max_generate_tokens: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> Generator:
-
+def sample_streamingly(
+ model: nn.Module,
+ input_ids: torch.Tensor,
+ max_generate_tokens: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> Generator:
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(max_generate_tokens):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
- 'input_ids': input_ids
- }
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
+ next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
return model_kwargs
class Dialogue(BaseModel):
- instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
- response: str = Field(example='')
+ instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
+ response: str = Field(example="")
-def _format_dialogue(instruction: str, response: str = ''):
- return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
+def _format_dialogue(instruction: str, response: str = ""):
+ return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
-STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
class ChatPromptProcessor:
- SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
+    SAFE_RESPONSE = "The input/response contains inappropriate content; please rephrase your prompt."
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
@@ -138,42 +140,48 @@ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words:
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
if self.context_len is None:
- self.context_len = len(self.tokenizer(self.context)['input_ids'])
+ self.context_len = len(self.tokenizer(self.context)["input_ids"])
if self.dialogue_placeholder_len is None:
self.dialogue_placeholder_len = len(
- self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])
+ self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
+ )
prompt = self.context
# the last dialogue must be in the prompt
last_dialogue = history.pop()
# the response of the last dialogue is empty
- assert last_dialogue.response == ''
- if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)
- ['input_ids']) + max_new_tokens + self.context_len >= self.max_len:
+ assert last_dialogue.response == ""
+ if (
+ len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+ + max_new_tokens
+ + self.context_len
+ >= self.max_len
+ ):
            # to avoid truncating the placeholder, truncate the original instruction instead
- instruction_truncated = self.tokenizer(last_dialogue.instruction,
- add_special_tokens=False,
- truncation=True,
- max_length=(self.max_len - max_new_tokens - self.context_len -
- self.dialogue_placeholder_len))['input_ids']
+ instruction_truncated = self.tokenizer(
+ last_dialogue.instruction,
+ add_special_tokens=False,
+ truncation=True,
+ max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
+ )["input_ids"]
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
prompt += _format_dialogue(instruction_truncated)
return prompt
- res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids'])
+ res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
rows = []
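+        # walk the history newest-first, keeping as many turns as still fit in the remaining token budget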
for dialogue in history[::-1]:
text = _format_dialogue(dialogue.instruction, dialogue.response)
- cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
+ cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
if res_len - cur_len < 0:
break
res_len -= cur_len
rows.insert(0, text)
- prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
+ prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt
def postprocess_output(self, output: str) -> str:
- output = STOP_PAT.sub('', output)
+ output = STOP_PAT.sub("", output)
return output.strip()
def has_censored_words(self, text: str) -> bool:
@@ -184,7 +192,6 @@ def has_censored_words(self, text: str) -> bool:
class LockedIterator:
-
def __init__(self, it, lock: Lock) -> None:
self.lock = lock
self.it = iter(it)
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index eb1a77875acb..93d48bcb6f79 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1,2 +1,2 @@
pytest
-colossalai==0.3.1
\ No newline at end of file
+colossalai==0.3.3
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index e5f5ca0932a8..e56aaca0e7cb 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
-colossalai==0.3.1
+colossalai==0.3.3
torch<2.0.0, >=1.12.1
langchain
tokenizers
@@ -11,3 +11,4 @@ sse_starlette
wandb
sentencepiece
gpustat
+tensorboard
diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py
index a285a6dff4bf..eb44b6203ef8 100644
--- a/applications/Chat/setup.py
+++ b/applications/Chat/setup.py
@@ -2,40 +2,42 @@
def fetch_requirements(path):
- with open(path, 'r') as fd:
+ with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
- with open('README.md', encoding='utf-8') as f:
+ with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
- with open('version.txt', 'r') as f:
+ with open("version.txt", "r") as f:
return f.read().strip()
setup(
- name='coati',
+ name="coati",
version=fetch_version(),
- packages=find_packages(exclude=(
- 'tests',
- 'benchmarks',
- '*.egg-info',
- )),
- description='Colossal-AI Talking Intelligence',
+ packages=find_packages(
+ exclude=(
+ "tests",
+ "benchmarks",
+ "*.egg-info",
+ )
+ ),
+ description="Colossal-AI Talking Intelligence",
long_description=fetch_readme(),
- long_description_content_type='text/markdown',
- license='Apache Software License 2.0',
- url='https://github.com/hpcaitech/Coati',
- install_requires=fetch_requirements('requirements.txt'),
- python_requires='>=3.6',
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/Coati",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
classifiers=[
- 'Programming Language :: Python :: 3',
- 'License :: OSI Approved :: Apache Software License',
- 'Environment :: GPU :: NVIDIA CUDA',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
- 'Topic :: System :: Distributed Computing',
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
],
)
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 3a3bf5b19cb8..9c08aa36c9b4 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -22,25 +22,21 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
return dict(input_ids=input_ids, attention_mask=attention_mask)
-def train_step(strategy: Strategy,
- actor: GPTActor,
- actor_optim: HybridAdam,
- batch_size: int = 8):
+def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
- actor_output = actor(data["input_ids"], data["attention_mask"])
- action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
+ actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
+ action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim)
-def run_test_checkpoint(strategy_name: str,
- shard: bool):
+def run_test_checkpoint(strategy_name: str, shard: bool):
if strategy_name == "ddp":
strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini":
- strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
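+        # (assumption) colossalai 0.3.3 takes "static"/"auto" placement policies instead of device names like "cuda"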
+ strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif strategy_name == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
@@ -60,12 +56,10 @@ def run_test_checkpoint(strategy_name: str,
dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0]
- model_path = os.path.join(
- rank0_dirname, "model" if shard else f"model.pt")
- strategy.save_model(actor, model_path, only_rank0=not shard)
- optim_path = os.path.join(
- rank0_dirname, "optim" if shard else "optim.pt")
- strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
+    model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
+ strategy.save_model(actor, model_path)
+ optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
+ strategy.save_optimizer(actor_optim, optim_path)
dist.barrier()
strategy.load_model(actor, model_path, strict=False)
@@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str,
train_step(strategy, actor, actor_optim)
-def run_dist(rank: int,
- world_size: int,
- port: int,
- strategy_name: str,
- shard: bool):
+def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
@@ -93,13 +83,8 @@ def run_dist(rank: int,
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
-def test_checkpoint(world_size: int,
- strategy_name: str,
- shard: bool):
- spawn(run_dist,
- world_size,
- strategy_name=strategy_name,
- shard=shard)
+def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
+ spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
if __name__ == "__main__":
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
index f9dee1bae935..ec61bbb13fd7 100644
--- a/applications/Chat/tests/test_dataset.py
+++ b/applications/Chat/tests/test_dataset.py
@@ -8,62 +8,40 @@
from coati.dataset.prompt_dataset import PromptDataset
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+
SFT_DATASET = [
{
- "instruction":
- "Provide a list of the top 10 most popular mobile games in Asia",
- "input":
- "",
- "output":
- "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id":
- 0
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0,
},
{
- "instruction":
- "Please provide an action plan for reducing carbon footprint on a corporate level",
- "input":
- "",
- "output":
- "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
- "id":
- 1
+ "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input": "",
+ "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id": 1,
},
{
- "instruction":
- "Write a persuasive email to your boss explaining why you should have a pay raise",
- "input":
- "",
- "output":
- "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
- "id":
- 2
+ "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input": "",
+ "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id": 2,
},
]
PROMPT_DATASET = [
{
- "instruction":
- "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id":
- 0
- },
- {
- "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
- "id": 1
- },
- {
- "instruction": "Write a persuasive essay arguing why homework should be banned in schools",
- "id": 2
- },
- {
- "instruction": "Create a chart comparing the statistics on student debt in the United States.",
- "id": 3
+ "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
+ "id": 0,
},
+ {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
+ {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
+ {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
]
@@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
json.dump(PROMPT_DATASET, f)
tokenizer = make_tokenizer(model)
assert tokenizer.padding_side in ("left", "right")
- prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name),
- tokenizer=tokenizer,
- max_datasets_size=max_datasets_size,
- max_length=max_length)
+ prompt_dataset = PromptDataset(
+ data_path=os.path.join(tmp_dir, dataset_name),
+ tokenizer=tokenizer,
+ max_datasets_size=max_datasets_size,
+ max_length=max_length,
+ )
assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
for i in range(len(prompt_dataset)):
assert isinstance(prompt_dataset[i], dict)
@@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
- ("Dahoas/rm-static", None)])
+@pytest.mark.parametrize(
+ ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
+)
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
- assert max_datasets_size <= len(data["train"]) \
- and max_datasets_size <= len(data["test"])
+ assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
train_data = data["train"].select(range(max_datasets_size))
test_data = data["test"].select(range(max_datasets_size))
tokenizer = make_tokenizer(model)
@@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert len(train_dataset) == len(test_dataset) == max_datasets_size
for i in range(max_datasets_size):
chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
- assert chosen_ids.shape == c_mask.shape == \
- reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
@@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask)
chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
- assert chosen_ids.shape == c_mask.shape == \
- reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
@@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask)
-
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
dataset_name = "sft_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
json.dump(SFT_DATASET, f)
- sft_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=os.path.join(tmp_dir, dataset_name),
- max_datasets_size=max_dataset_size,
- max_length=max_length)
+ sft_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=os.path.join(tmp_dir, dataset_name),
+ max_datasets_size=max_dataset_size,
+ max_length=max_length,
+ )
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
if isinstance(tokenizer, ChatGLMTokenizer):
@@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
assert input_ids.shape == labels.shape == torch.Size([max_length])
-
+
ignore_mask = labels == IGNORE_INDEX
assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
return
-
+
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
- assert input_ids.shape == labels.shape == \
- attention_mask.shape == torch.Size([max_length])
+ assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
@@ -248,19 +226,16 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
assert torch.all(attention_mask)
ignore_mask = labels == IGNORE_INDEX
- check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
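+        # labels hold IGNORE_INDEX over both the prompt and the padding; only prompt positions contain real tokens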
+ prompt_mask = torch.logical_and(ignore_mask, attention_mask)
+ check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
+ assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
if __name__ == "__main__":
test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
- test_reward_dataset(model="gpt2",
- dataset_path="Anthropic/hh-rlhf",
- subset="harmless-base",
- max_datasets_size=8,
- max_length=256)
-
- test_prompt_dataset(model="opt",
- max_datasets_size=2,
- max_length=128)
+ test_reward_dataset(
+ model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
+ )
+ test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py
index 071e50b90e8e..a9591259800d 100644
--- a/applications/Chat/tests/test_experience.py
+++ b/applications/Chat/tests/test_experience.py
@@ -1,5 +1,5 @@
+import copy
import os
-from copy import deepcopy
import pytest
import torch
@@ -8,6 +8,7 @@
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
+from coati.trainer.ppo import _set_default_generate_kwargs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
@@ -18,7 +19,7 @@
def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -37,34 +38,43 @@ def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2
- if strategy == 'ddp':
+ if strategy == "ddp":
strategy = DDPStrategy()
- elif strategy == 'colossalai-zero2':
+ elif strategy == "colossalai-zero2":
strategy = LowLevelZeroStrategy()
- elif strategy == 'colossalai-gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
+ elif strategy == "colossalai-gemini":
+ strategy = GeminiStrategy(placement_policy="static")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
- actor = GPTActor(config=GPT_CONFIG).cuda()
- critic = GPTCritic(config=GPT_CONFIG).cuda()
+ with strategy.model_init_context():
+ actor = GPTActor(config=GPT_CONFIG).cuda()
+ critic = GPTCritic(config=GPT_CONFIG).cuda()
- initial_model = deepcopy(actor)
- reward_model = RewardModel(deepcopy(critic.model)).cuda()
+ initial_model = GPTActor(config=GPT_CONFIG).cuda()
+ reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
- experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
+ actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
+
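+    # minimal stand-in for a Hugging Face tokenizer: experience making only needs padding_side and the special-token ids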
+ class MockTokenizer:
+ def __init__(self):
+ self.padding_side = "left"
+ self.eos_token_id = 0
+ self.pad_token_id = 0
+
+ tokenizer = MockTokenizer()
+ experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
+ generate_kwargs = dict(do_sample=True, max_length=16)
+ generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+
# experience of all ranks should be the same
for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE)
- assert gather_and_equal(data['input_ids'])
- assert gather_and_equal(data['attention_mask'])
- experience = experience_maker.make_experience(**data,
- do_sample=True,
- max_length=16,
- eos_token_id=50256,
- pad_token_id=50256)
+ assert gather_and_equal(data["input_ids"])
+ assert gather_and_equal(data["attention_mask"])
+ experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values)
@@ -75,7 +85,7 @@ def make_and_consume_experience(strategy):
data_buffer.append(experience)
# data buffer's data should be the same
- buffer_size = torch.tensor([len(data_buffer)], device='cuda')
+ buffer_size = torch.tensor([len(data_buffer)], device="cuda")
assert gather_and_equal(buffer_size)
for item in data_buffer.items:
assert gather_and_equal(item.sequences)
@@ -88,7 +98,7 @@ def make_and_consume_experience(strategy):
    # dataloaders on each rank should have the same length but yield different batches
dataloader = strategy.setup_dataloader(data_buffer)
- dataloader_size = torch.tensor([len(dataloader)], device='cuda')
+ dataloader_size = torch.tensor([len(dataloader)], device="cuda")
assert gather_and_equal(dataloader_size)
for experience in dataloader:
assert not gather_and_equal(experience.sequences)
@@ -100,21 +110,21 @@ def make_and_consume_experience(strategy):
def run_dist(rank, world_size, port, strategy):
- os.environ['RANK'] = str(rank)
- os.environ['LOCAL_RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(port)
+ os.environ["RANK"] = str(rank)
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
make_and_consume_experience(strategy)
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy)
-if __name__ == '__main__':
- test_experience(2, 'colossalai')
+if __name__ == "__main__":
+ test_experience(2, "colossalai-zero2")
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
index b98b3615cd28..b2c22ac6a3b9 100644
--- a/applications/Chat/tests/test_models.py
+++ b/applications/Chat/tests/test_models.py
@@ -6,15 +6,16 @@
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.chatglm import ChatGLMActor
+from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from coati.models.utils import calc_action_log_probs, masked_mean
+
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@@ -23,23 +24,34 @@
[
lambda: BLOOMActor(),
lambda: GPTActor(),
- # HACK: skip llama due to long execution time
- # lambda: LlamaActor(),
- lambda: OPTActor(),
- # lambda: ChatGLMActor(),
-])
-
-@pytest.mark.parametrize("generate_kwargs", [{
- "max_length": 64,
- "use_cache": True,
- "do_sample": True,
- "temperature": 1.0,
- "top_k": 50,
-}])
+ # HACK: skip llama due to long execution time
+ # lambda: LlamaActor(),
+ lambda: OPTActor(),
+ ],
+)
+@pytest.mark.parametrize(
+ "generate_kwargs",
+ [
+ {
+ "max_length": 64,
+ "use_cache": True,
+ "do_sample": True,
+ "temperature": 1.0,
+ "top_k": 50,
+ }
+ ],
+)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
+ class MockTokenizer:
+ def __init__(self):
+ self.padding_side = "left"
+ self.eos_token_id = 0
+ self.pad_token_id = 0
+
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
- sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
+ tokenizer = MockTokenizer()
+ sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
@@ -49,26 +61,12 @@ def test_utils():
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
- batch_size = 4
- num_labels = 10
- fn_input = {
- "r": torch.ones((batch_size,)),
- "kl_coef": 1.0,
- "log_probs": torch.randn((batch_size, num_labels)),
- "log_probs_base": torch.randn((batch_size, num_labels)),
- "action_mask": torch.randint(0, 2, (batch_size, num_labels))
- }
- fn_output = compute_reward(**fn_input)
- assert fn_output.shape == (batch_size,)
-
batch_size = 4
seq_len = 32
num_labels = 10
num_actions = 2
fn_input = {
- "output": {
- "logits": torch.randn((batch_size, seq_len, num_labels))
- },
+ "logits": torch.randn((batch_size, seq_len, num_labels)),
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions,
}
@@ -105,8 +103,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
- lora_model[i].lora_B @ lora_model[i].lora_A)
+ assert not torch.allclose(
+ old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
+ )
@pytest.mark.parametrize("batch_size", [8])
@@ -116,54 +115,59 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
- # HACK: skip llama due to long execution time
- # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
- lambda: (OPTActor(), OPTCritic(), OPTRM()),
- lambda: (ChatGLMActor(), None, None),
-])
+ # HACK: skip llama due to long execution time
+ # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
+ lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
+ ],
+)
@torch.no_grad()
-def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
- batch_size: int,
- seq_len: int):
+def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
- "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
rm_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
actor, critic, rm = models_maker()
if isinstance(actor, ChatGLMActor):
actor = actor.float()
- tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True)
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
- actor_input ={
- "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
- "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
- }
+ actor_input = {
+ "input_ids": torch.cat(
+ (
+ torch.randint(0, 100, (batch_size, seq_len // 2)),
+ chatglm_special_token,
+ torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
+ ),
+ dim=1,
+ ),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
+ }
assert isinstance(actor, Actor)
- base_actor_model = get_base_model(actor)
+ get_base_model(actor)
actor_output = actor(**actor_input)
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
if critic:
assert isinstance(critic, Critic)
- base_critic_model = get_base_model(critic)
+ get_base_model(critic)
critic_output = critic(**critic_input)
- assert critic_output.shape == (batch_size, )
-
+ assert critic_output.shape == (batch_size,)
+
if rm:
assert isinstance(rm, RewardModel)
- base_rm_model = get_base_model(rm)
+ get_base_model(rm)
rm_output = rm(**rm_input)
- assert rm_output.shape == (batch_size, )
+ assert rm_output.shape == (batch_size,)
@pytest.mark.parametrize("batch_size", [16])
@@ -173,39 +177,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
- "labels": torch.randint(0, num_labels, (batch_size, seq_len))
+ "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = PolicyLoss()
loss_input = {
- "log_probs": torch.randn(batch_size,),
- "old_log_probs": torch.randn(batch_size,),
- "advantages": torch.randn(batch_size,)
+ "log_probs": torch.randn(
+ batch_size,
+ ),
+ "old_log_probs": torch.randn(
+ batch_size,
+ ),
+ "advantages": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = ValueLoss()
loss_input = {
- "values": torch.randn(batch_size,),
- "old_values": torch.randn(batch_size,),
- "reward": torch.randn(batch_size,)
+ "values": torch.randn(
+ batch_size,
+ ),
+ "old_values": torch.randn(
+ batch_size,
+ ),
+ "reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = LogSigLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size,),
- "reject_reward": torch.randn(batch_size,),
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = LogExpLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size,),
- "reject_reward": torch.randn(batch_size,),
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
if __name__ == "__main__":
@@ -218,4 +242,4 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int):
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
- test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh
index c5127c188612..68fca7fbf8c0 100755
--- a/applications/Chat/tests/test_train.sh
+++ b/applications/Chat/tests/test_train.sh
@@ -24,8 +24,8 @@ if [ -z "$SFT_DATASET" ]; then
exit 1
fi
-if [ -z "$PROMPT_PATH" ]; then
- echo "Please set \$PROMPT_PATH to the path to prompts csv."
+if [ -z "$PROMPT_DATASET" ]; then
+ echo "Please set \$PROMPT_DATASET to the path to prompts csv."
exit 1
fi
@@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config
MODELS=('gpt2' 'bloom' 'opt' 'llama')
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
+
export OMP_NUM_THREADS=8
# install requirements
@@ -74,6 +75,7 @@ echo "[Test]: testing sft ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
@@ -82,7 +84,7 @@ SKIPPED_TESTS=(
)
GRAD_CKPTS=('' '--grad_checkpoint')
-for lora_rank in '0' '4'; do
+for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
@@ -105,7 +107,7 @@ for lora_rank in '0' '4'; do
$pretrain_model --tokenizer $MODELS_DIR/$model \
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
--dataset $SFT_DATASET --max_datasets_size 8 \
- --max_epochs 1 --batch_size 1 --accumulation_steps 1 \
+ --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
passed=$?
if [ $passed -eq 0 ]; then
@@ -125,6 +127,7 @@ echo "[Test]: testing reward model ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
@@ -134,7 +137,7 @@ SKIPPED_TESTS=(
LOSS_FNS=('log_sig' 'log_exp')
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
-for lora_rank in '0' '4'; do
+for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
@@ -157,8 +160,9 @@ for lora_rank in '0' '4'; do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
$pretrain_model --tokenizer $MODELS_DIR/$model \
- --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
- --dataset $dataset --subset $subset --test True --batch_size 1 \
+ --dataset $dataset --subset $subset --max_datasets_size 8 \
+ --model $model --strategy $strategy --lora_rank $lora_rank \
+ --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
passed=$?
if [ $passed -eq 0 ]; then
@@ -178,6 +182,7 @@ echo "[Test]: testing RLHF ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
@@ -186,7 +191,7 @@ SKIPPED_TESTS=(
)
for model in ${MODELS[@]}; do
- for lora_rank in '0' '4'; do
+ for lora_rank in '0'; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
@@ -204,13 +209,13 @@ for model in ${MODELS[@]}; do
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
- --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+ --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
- --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
+ --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
- --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
+ --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
passed=$?
if [ $passed -eq 0 ]; then
break
@@ -225,4 +230,4 @@ for model in ${MODELS[@]}; do
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
done
done
-rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
+rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md
new file mode 100644
index 000000000000..34967c04360c
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/README.md
@@ -0,0 +1,390 @@
+# Colossal-LLaMA-2
+
+## Table of Contents
+- [News](#news)
+- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b)
+ - [Performance Evaluation](#performance-evaluation)
+ - [Examples](#examples)
+ - [Training Logs](#training-logs)
+ - [Import from Transformers](#import-from-transformers)
+- [Usage](#usage)
+ - [Install](#install)
+ - [How to run](#how-to-run)
+- [Technical Insights](#technical-insights)
+ - [Data](#data)
+ - [Tokenizer](#tokenizer)
+ - [Training Strategy](#training-strategy)
+ - [Bridging Any Domain-specific Large Models](#bridging-any-domain-specific-large-models)
+- [Citations](#citations)
+
+## News
+* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
+[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base)
+
+## Colossal-LLaMA-2-7B
+The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks.
+
+Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others.
+
+❗️**Important notice**:
+* All training data used for this project is collected from well-known public datasets.
+* We do not use any testing data from the evaluation benchmarks for training.
+
+### Performance Evaluation
+We conducted a comprehensive evaluation on 4 datasets and compared our Colossal-LLaMA-2-7b-base model with various other models.
+
+* We use 5-shot for MMLU and calculate scores based on the logits of the first predicted token.
+* We use 5-shot for CMMLU and calculate scores based on the logits of the first predicted token.
+* We use 5-shot for AGIEval and only calculate scores for 4-choice questions, using a combined metric of exact match and the logits of the first predicted token. If either the exact match or the first predicted token is correct, the model receives the score.
+* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of the first predicted token.
+* The generation config for all datasets is greedy search.
+* We also provide CEval scores from the latest leaderboard or the official repository of each model.
+
+| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
+| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: |
+| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
+| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| | | | | | | | | |
+| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
+
+> The scores in parentheses correspond to the scores in the model's official repository.
+>
+> We use zero-shot for ChatGLM models.
+>
+> Qwen-7B is now inaccessible on Hugging Face; we used the latest version available before it was taken down. For the MMLU dataset only, the prompt is "xxx Answer:" (with the space after ":" removed), and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models; for example, the logits over " A" can be `-inf`, making its softmax exactly `0`.
+>
+> For other models and datasets, we calculate the logits over "A", "B", "C" and "D".
+
+❗️ For more details on the evaluation methods and how to reproduce the results, please refer to [ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval).
+
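+As a rough illustration of the first-token-logit scoring described above, here is a minimal sketch (not the official ColossalEval implementation; the model path and prompt are hypothetical placeholders):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# "<model_path>" is a placeholder for any causal LM in Hugging Face format.
+model = AutoModelForCausalLM.from_pretrained("<model_path>", trust_remote_code=True).eval()
+tokenizer = AutoTokenizer.from_pretrained("<model_path>", trust_remote_code=True)
+
+prompt = "<five-shot examples>\nQuestion: ...\nA. ...\nB. ...\nC. ...\nD. ...\nAnswer:"
+choices = ["A", "B", "C", "D"]
+choice_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]
+
+with torch.no_grad():
+    inputs = tokenizer(prompt, return_tensors="pt")
+    # Logits of the first token the model would generate after "Answer:".
+    first_token_logits = model(**inputs).logits[0, -1]
+
+# The choice whose token has the highest logit is taken as the prediction.
+prediction = choices[max(range(len(choices)), key=lambda i: first_token_logits[choice_ids[i]].item())]
+```
+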
+### Examples
+| Question Type | Question | Colossal-LLaMA-2-7b-base |
+| :------: | :-----------------------: | :----------- |
+| Knowledge Quiz<br>(Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 |
+| Knowledge Quiz<br>(Computer Science) | 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 |
+| Knowledge Quiz<br>(Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 |
+| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 |
+| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. |
+| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:<br>1. Determine the weight of 1 book.<br>2. Multiply the weight of 1 book by the number of books.<br>Step 1: Determine the weight of 1 book.<br>Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.<br>5kg ÷ 10 = 0.5kg<br>Step 2: Multiply the weight of 1 book by the number of books.<br>To find the weight of 2 books, we multiply the weight of 1 book by 2.<br>0.5kg × 2 = 1kg<br>So, the weight of 2 books is 1kg. |
+| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. |
+| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." |
+
+❗️ For more question-answering examples, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md).
+
+### Training Logs
+We also recorded the training logs for the experiment:
+
+*(figures: training logs)*
+
+### Import from Transformers (Inference)
+To load Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
+```Python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+model = AutoModelForCausalLM.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", device_map="auto", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True)
+input = "离离原上草,"
+inputs = tokenizer(input, return_tensors='pt')
+inputs = inputs.to('cuda:0')
+pred = model.generate(**inputs,
+ max_new_tokens=256,
+ do_sample=True,
+ top_k=50,
+ top_p=0.95,
+ num_return_sequences=1)
+print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
+```
+
+You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
+
+## Usage
+### Install
+
+#### 0. Pre-requisite
+1. This experiment was performed on 8 computing nodes with 64 A800 GPUs in total for LLaMA-2-7B (**about 1000 USD cost**). The nodes are connected with RDMA, and the GPUs within one node are fully connected with NVLink. The script was tested with CUDA 11.7; CUDA 11.7 or higher is required. Alternatively, you can complete the training in about 5 days on a single 8×A100/A800 server.
+
+2. PyTorch. The PyTorch version should be greater than 1.12.1 and less than 2.0.0.
+
+
+#### 1. Install required packages
+```bash
+cd Colossal-LLaMA-2
+pip install -r requirements.txt
+```
+#### 2. Install `xentropy`, `layer_norm` and `rotary`
+```bash
+git clone git@github.com:Dao-AILab/flash-attention.git
+cd flash-attention
+# Install each extension from the root folder of the flash-attention repository
+cd csrc/xentropy && pip install . && cd ../..
+cd csrc/layer_norm && pip install . && cd ../..
+cd csrc/rotary && pip install . && cd ../..
+```
+
+### How to run
+
+#### 1. Init Tokenizer Preparation
+Initialize a new tokenizer with additional Chinese tokens. The additional Chinese tokens are stored in `jsonl` format as follows:
+```json
+{"piece": "你好"}
+{"piece": "人工智能"}
+```
+Command to initialize new tokenizer:
+```bash
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
+python colossal_llama2/tokenizer/init_tokenizer.py \
+    --source_tokenizer_dir "<source_tokenizer_dir>" \
+    --target_tokenizer_dir "<target_tokenizer_dir>" \
+    --expand_tokens_file "<expand_tokens_file>.jsonl"
+```
+Here are the details of the CLI arguments:
+* Source tokenizer directory: `--source_tokenizer_dir`. Directory to the source tokenizer. It should at least contain three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`.
+* Target tokenizer directory: `--target_tokenizer_dir`. Directory to the target tokenizer.
+* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer.
+
+#### 2. Init Model Preparation
+Initialize the new model checkpoint by calculating the mean values from the original model checkpoint.
+Command to initialize new model checkpoint:
+```bash
+python colossal_llama2/model/init_model.py \
+    --source_model_and_tokenizer_path "<source_model_and_tokenizer_path>" \
+    --target_tokenizer_path "<target_tokenizer_path>" \
+    --target_model_path "<target_model_path>"
+```
+"" can be the same as "".
+
+Here are the details of the CLI arguments:
+* Source model and tokenizer path: `--source_model_and_tokenizer_path`. The source folder containing both the model and the tokenizer, e.g., a LLaMA-2 model in Hugging Face format.
+* Target tokenizer path: `--target_tokenizer_path`. Path to the new tokenizer folder generated from previous step.
+* Target model path: `--target_model_path`. Path to save the new model in Hugging Face format.
+
+❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
+
+#### 3. Data Preparation
+Raw data should be in `jsonl` format. Each data point should have the following fields:
+* `source` (str, compulsory): This part is ignored when calculating the loss. It can be empty.
+* `target` (str, compulsory): The loss will be calculated on this part.
+* `category` (str, compulsory): Tags for each data point.
+
+Examples:
+```JSON
+{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"}
+{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"}
+```
+You can customize the category tags or use `unknown` as the category.
+
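+To make the loss masking above concrete, here is a minimal sketch of how `source` and `target` map to training labels (the full implementation is `supervised_tokenize` in `colossal_llama2/dataset/spliced_and_tokenized_dataset.py`, added later in this diff):
+
+```python
+IGNORE_INDEX = -100  # positions with this label are excluded from the loss
+
+def build_labels(source_ids: list, target_ids: list) -> dict:
+    # The loss is computed on `target` tokens only; `source` tokens are masked out.
+    input_ids = source_ids + target_ids
+    labels = [IGNORE_INDEX] * len(source_ids) + list(target_ids)
+    return {"input_ids": input_ids, "labels": labels}
+```
+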
+Command to convert jsonl dataset to arrow format:
+```bash
+python prepare_pretrain_dataset.py \
+    --data_input_dirs "<dir1>,<dir2>,<dir3>" \
+    --tokenizer_dir "<tokenizer_dir>" \
+ --data_cache_dir "jsonl_to_arrow_cache" \
+ --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
+ --data_arrow_output_dir "spliced_tokenized_output_arrow" \
+ --max_length 4096 \
+ --num_spliced_dataset_bins 10
+```
+Here are the details of the CLI arguments:
+* Source data directory: `data_input_dirs`. Each `<dir>` can contain multiple files in `jsonl` format.
+* Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
+* Data cache directory: `data_cache_dir`. Directory to store the Hugging Face data cache. By default, a `cache` folder is created locally.
+* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store the converted dataset in jsonl format.
+* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store the converted dataset in arrow format, which can be used for training directly.
+* Max length: `max_length`. Max length of spliced samples. The default value is 4096.
+* Number of bins for each category: `num_spliced_dataset_bins`. The number of bins for each category, used for bucket-based training.
+
+#### 4. Command Line Arguments for Training
+You can use `colossalai run` to launch multi-node training:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+train.py --OTHER_CONFIGURATIONS
+```
+Here is a sample hostfile:
+```bash
+hostname1
+hostname2
+hostname3
+hostname4
+```
+Make sure the master node can access all nodes (including itself) via passwordless SSH.
+
+Here are the details of the CLI arguments:
+* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
+* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
+* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
+* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. A saved checkpoint contains the states for `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded, without any other states, to support multi-stage training.
+* Save interval: `--save_interval`. The interval (in steps) for saving checkpoints. The default value is 1000.
+* Checkpoint directory: `--save_dir`. The directory path to save checkpoints and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`.
+* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
+* Configuration file: `--config_file`. The path to save the configuration file.
+* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
+* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
+* Learning rate: `--lr`. The default value is 3e-4.
+* Max length: `--max_length`. Max context length. The default value is 4096.
+* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
+* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
+* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
+* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated from a warmup ratio of 0.025.
+* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. We recommend enabling this option when training with a large batch size.
+* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and the related packages. The default value is `False`. This helps accelerate training while saving memory. We recommend always using flash attention.
+* Freeze non-embedding parameters: `--freeze_non_embeds_params`. This can be helpful for aligning the embeddings after extending the vocabulary size.
+* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
+* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
+
+#### 5. Running Command
+An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run it:
+* Create your own hostfile: `cp hostfile.example hostfile`.
+* Create your own bash: `cp train.example.sh train.sh`.
+* Add your real host ip or host name into the `hostfile`.
+* Update global variables and parameters in your `train.sh`.
+* Run the experiment with `bash train.sh`.
+
+Here are the details about the global variables for each experiment:
+* `PROJECT_NAME`: Project name for each experiment.
+* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint.
+* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs.
+* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment.
+* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint.
+* `dataset`: Paths to all prepared data. Typically, it is a list of subfolders within the output path of the data preparation step (`--data_arrow_output_dir`); if there are multiple subfolders, list them all, e.g.,
+```bash
+declare -a dataset=(
+    "<data_arrow_output_dir_1>/part-00000"
+    "<data_arrow_output_dir_1>/part-00001"
+    "<data_arrow_output_dir_2>/part-00000"
+)
+```
+## Technical Insights
+In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continuing the pre-training of the LLaMA-2 model with both Chinese and English corpora. The overall pipeline can be described as follows:
+
+*(figure: overall pipeline of continual pre-training)*
+
+### Data
+Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset.
+
+The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2.
+
+*(figure: data processing pipeline)*
+
+❗️**Important**: We will open-source our data-processing toolkit soon, stay tuned!
+
+### Tokenizer
+The original LLaMA-2 vocabulary comprises fewer than a thousand Chinese characters, which proves inadequate for encoding comprehensive Chinese texts effectively. In addition, the use of byte tokens makes it difficult for transformer encoders to capture the semantic nuances of Chinese characters.
+
+To address the above issues, we extend the LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model for use with the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and subsequently append these new rows to the end of the original embedding matrices.
+
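+A minimal sketch of this mean initialization (the full script is `colossal_llama2/model/init_model.py`, added later in this diff; `new_token_to_old_ids` is a hypothetical mapping from each new token to the ids of the original tokens it decomposes into):
+
+```python
+import torch
+
+def mean_init_new_rows(old_embedding: torch.Tensor, new_token_to_old_ids: list) -> torch.Tensor:
+    # Average the original embeddings of each new token's constituent pieces,
+    # then append the averaged rows to the original embedding matrix.
+    new_rows = [old_embedding[ids].mean(dim=0) for ids in new_token_to_old_ids]
+    return torch.cat([old_embedding, torch.stack(new_rows)], dim=0)
+```
+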
+Advantages of extending vocabulary size:
+* Improve the compression rate of string sequence encoding.
+* Enhance the integrity of information.
+* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding.
+
+Disadvantages of a large vocabulary size under low-resource settings:
+* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned.
+* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process.
+
+To balance both sides, we finally constructed our vocabulary with a size of 69,104. The following table presents a comparison of various models at the 7B level.
+
+| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) |
+| :-----------: | :---------: | :----: | :----: |
+| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 |
+| LLaMA-2-7B | 32000 | 1.205 | 134.689 |
+| Atom-7B | 65000 | 0.634 | 70.915 |
+| Baichuan-7B | 64000 | 0.678 | 75.857 |
+| Baichuan2-7B-base | 125696 | 0.570 | 63.761 |
+| Chatglm2-6B | 64789 | 0.645 | 72.178 |
+| InternLM-7B | 103168 | 0.566 | 63.349 |
+| Qwen-7B | 151643 | 0.578 | 64.703 |
+| Tigerbot-7B-base | 60515 | 0.630 | 70.515 |
+| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 |
+| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 |
+| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 |
+| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 |
+| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 |
+
+
+### Training Strategy
+#### Multi-stage Training
+In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages.
+
+Therefore, we have divided the training process into three stages:
+* Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens.
+* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language.
+* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains.
+
+Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks.
+
+The following figure illustrates the three stages for training Colossal-LLaMA-2.
+
+*(figure: the three training stages)*
+
+#### Bucket-based Training
+Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2.
+
+In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset.
+
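+A minimal sketch of the bucket construction described above (a hypothetical helper, assuming each sub-dataset is a list of samples):
+
+```python
+import random
+
+def build_buckets(sub_datasets: list, num_bins: int) -> list:
+    # Split each sub-dataset into `num_bins` bins, then form each bucket by
+    # taking one bin from every sub-dataset, balancing the topic distribution.
+    binned = []
+    for ds in sub_datasets:
+        ds = list(ds)
+        random.shuffle(ds)
+        bin_size = max(1, len(ds) // num_bins)
+        binned.append([ds[i * bin_size : (i + 1) * bin_size] for i in range(num_bins)])
+    return [[sample for bins in binned for sample in bins[b]] for b in range(num_bins)]
+```
+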
+### Bridging Any Domain-specific Large Models
+Applying the above process to perform knowledge transfer in any field allows for the cost-effective construction of lightweight domain-specific foundational large models.
+
+*(figure: knowledge transfer to domain-specific models)*
+
+## Citations
+```bibtex
+@article{bian2021colossal,
+ title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
+ author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
+ journal={arXiv preprint arXiv:2110.14883},
+ year={2021}
+}
+```
+```bibtex
+@misc{touvron2023llama,
+ title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
+ author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
+ year={2023},
+ eprint={2307.09288},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
+```bibtex
+@article{dao2023flashattention2,
+ title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
+ author={Dao, Tri},
+ year={2023}
+}
+```
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py
new file mode 100644
index 000000000000..56fafa58b3f4
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py
new file mode 100644
index 000000000000..56fafa58b3f4
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
new file mode 100644
index 000000000000..a2cfb2ef6264
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import random
+from dataclasses import dataclass
+from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+
+import torch
+from datasets import dataset_dict, load_from_disk
+from datasets import Dataset as HFDataset
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
+import torch.nn.functional as F
+
+DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+PathType = Union[str, os.PathLike]
+
+
+def load_tokenized_dataset(
+ dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
+) -> Optional[DatasetType]:
+ """
+ Load pre-tokenized dataset.
+ Each instance of dataset is a dictionary with
+    `{'input_ids': List[int], 'labels': List[int], 'sequence': str}` format.
+ """
+ mode_map = {"train": "train", "dev": "validation", "test": "test"}
+ assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
+
+ if isinstance(dataset_paths, (str, os.PathLike)):
+ dataset_paths = [dataset_paths]
+
+ datasets = [] # `List[datasets.dataset_dict.Dataset]`
+ for ds_path in dataset_paths:
+ ds_path = os.path.abspath(ds_path)
+        assert os.path.exists(ds_path), f"File path does not exist: {ds_path}"
+ ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
+ if isinstance(ds_dict, HFDataset):
+ datasets.append(ds_dict)
+ else:
+ if mode_map[mode] in ds_dict:
+ datasets.append(ds_dict[mode_map[mode]])
+ if len(datasets) == 0:
+ return None
+ if len(datasets) == 1:
+ return datasets.pop()
+ return ConcatDataset(datasets=datasets)
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """
+ Collate instances for supervised dataset.
+ Each instance is a tokenized dictionary with fields
+ `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
+ """
+
+ tokenizer: PreTrainedTokenizer
+ max_length: int = 4096
+ ignore_index: int = -100
+
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary.
+
+ Returns:
+ (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
+ `input_ids`: `torch.Tensor` of shape (bsz, max_len);
+ `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
+ `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
+ """
+ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
+ f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
+ f"but now `{self.tokenizer.pad_token_id}`"
+ )
+
+ # `List[torch.Tensor]`
+ batch_input_ids = [
+ torch.LongTensor(instance["input_ids"][: self.max_length])
+ if len(instance["input_ids"]) > self.max_length
+ else torch.LongTensor(instance["input_ids"])
+ for instance in instances
+ ]
+ batch_labels = [
+ torch.LongTensor(instance["labels"][: self.max_length])
+ if len(instance["labels"]) > self.max_length
+ else torch.LongTensor(instance["labels"])
+ for instance in instances
+ ]
+
+ if self.tokenizer.padding_side == "right":
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ # pad to max
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+ elif self.tokenizer.padding_side == "left":
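+            # Left-padding trick: reverse each sequence, right-pad the reversed batch,
+            # then flip back so the padding ends up on the left.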
+ reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
+ reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
+ reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
+ reversed_labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
+ else:
+ raise RuntimeError(
+ f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
+ f"but now `{self.tokenizer.padding_side}`"
+ )
+
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
+
+ return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+
+class StatefulDistributedSampler(DistributedSampler):
+ """
+ Stateful distributed sampler for multi-stage training.
+ """
+
+ def __init__(
+ self,
+ dataset: DatasetType,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ ) -> None:
+ super().__init__(
+ dataset=dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+ self.start_index = 0
+
+ def __iter__(self) -> Iterator:
+ iterator = super().__iter__()
+ indices = list(iterator)
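+        # Drop the indices already consumed in a previous run so training resumes mid-epoch.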
+ indices = indices[self.start_index :]
+ return iter(indices)
+
+ def __len__(self) -> int:
+ return self.num_samples - self.start_index
+
+ def set_start_index(self, start_index: int) -> None:
+ self.start_index = start_index
+
+
+def setup_distributed_dataloader(
+ dataset: DatasetType,
+ batch_size: int = 1,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
+ process_group: Optional[ProcessGroup] = None,
+ **kwargs,
+) -> DataLoader:
+ """
+ Setup dataloader for distributed training.
+ """
+ _kwargs = kwargs.copy()
+ process_group = process_group or _get_default_group()
+ sampler = StatefulDistributedSampler(
+ dataset=dataset,
+ num_replicas=process_group.size(),
+ rank=process_group.rank(),
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id: int) -> None:
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=drop_last,
+ worker_init_fn=seed_worker,
+ **_kwargs,
+ )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
new file mode 100644
index 000000000000..0c21f325ae62
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Splicing multiple pre-tokenized sequence data points
+"""
+
+import random
+import warnings
+from copy import deepcopy
+from datasets import dataset_dict
+from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+
+from torch.utils.data import ConcatDataset, Dataset, IterableDataset
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+IGNORE_INDEX = -100
+
+DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+
+
+def supervised_tokenize(
+ data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+    A tokenization function to tokenize an original pretraining data point as follows:
+ {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
+ """
+ assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
+ "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
+ "add and manually later"
+ )
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ source_text = data_point["source"] # `str`
+ target_text = data_point["target"] # `str`
+ is_null_source = len(source_text) == 0
+
+ source_text = tokenizer.bos_token + source_text
+ target_text += tokenizer.eos_token
+ sequence_text = source_text + target_text
+
+ tokenized = tokenizer([source_text, sequence_text])["input_ids"]
+ sequence_input_ids = tokenized[1]
+ sequence_labels = deepcopy(sequence_input_ids)
+
+ source_length = len(tokenized[0])
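+    # Mask the source prefix with `ignore_index` so the loss is computed only on the target tokens.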
+ if not is_null_source:
+ sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
+
+ # sequence truncation.
+ if len(sequence_input_ids) > max_length:
+ sequence_input_ids = sequence_input_ids[:max_length]
+ sequence_labels = sequence_labels[:max_length]
+
+ return dict(
+ input_ids=sequence_input_ids,
+ labels=sequence_labels,
+ seq_length=len(sequence_input_ids),
+ seq_category=data_point["category"],
+ )
+
+
+class ClosedToConstantLengthSplicedDataset(IterableDataset):
+ """
+ Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
+ original independent (pre-tokenized) data points.
+ """
+
+ def __init__(
+ self,
+ dataset: DSType,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int = 4096,
+ num_packed_sequences: int = 8,
+ fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
+ input_ids_field: str = "input_ids",
+ labels_field: str = "labels",
+ infinite: bool = False,
+ shuffle: bool = True,
+ error_strict: bool = False,
+ ) -> None:
+ self.tokenizer = tokenizer
+ self.dataset = dataset
+ self.max_length = max_length
+ self.infinite = infinite
+ self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
+ self.shuffle = shuffle
+
+ # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
+ # A function that fetch sequence input_ids and labels from the original data point
+ if fetch_sequence_func is None:
+ self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
+ else:
+ self.fetch_sequence_func = fetch_sequence_func
+ self.input_ids_field = input_ids_field
+ self.labels_field = labels_field
+
+ self.error_strict = error_strict
+ self.current_size = 0 # `int`, current packed data size.
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ def __iter__(self) -> Iterable[Dict[str, List[int]]]:
+ iterator = iter(self.dataset)
+ more_data_points = True
+ while more_data_points is True:
+ buffer, buffer_len = [], 0
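+            # Fill the buffer with samples until it holds roughly `max_buffer_size` tokens, then splice them.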
+ while True:
+ # ending condition.
+ if buffer_len >= self.max_buffer_size:
+ break
+ try:
+ # `Tuple[List[int], List[int]]`
+ seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
+ buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
+ buffer_len += len(buffer[-1][self.input_ids_field])
+ except StopIteration:
+ if self.infinite is True:
+ iterator = iter(self.dataset)
+ warnings.warn("The dataset reached end and the iterator is reset to the start.")
+ else:
+ more_data_points = False
+ break
+ examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
+ spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
+ for i, data_point in enumerate(buffer):
+ # TODO(2023-09-18) check errors for each unspliced tokenized data point
+ seq_input_ids = data_point[self.input_ids_field]
+ seq_labels = data_point[self.labels_field]
+ # Handle special case:
+ # If the length of an original data point (i.e., input_ids length of a data point before splicing)
+ # exceeds `max_length`, truncate it.
+ if len(seq_input_ids) > self.max_length:
+ truncated_seq_input_ids = seq_input_ids[: self.max_length]
+ truncated_label_ids = seq_labels[: self.max_length]
+ if set(truncated_label_ids) == {IGNORE_INDEX}:
+ if self.error_strict is True:
+ raise ValueError(
+ f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
+ f"with all label values as {IGNORE_INDEX}."
+ )
+ else:
+ warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
+ continue # Skip the current error data point.
+ spliced_data_point = {
+ self.input_ids_field: truncated_seq_input_ids,
+ self.labels_field: truncated_label_ids,
+ }
+ examples.append(spliced_data_point)
+ warnings.warn("Find a data point to be truncated.")
+ continue
+
+ # Pre action judgment.
+ if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
+ spliced_data_point = {
+ self.input_ids_field: spliced_input_ids,
+ self.labels_field: spliced_labels,
+ } # `Dict[str, List[int]]`
+ # Update.
+ spliced_input_ids, spliced_labels = [], []
+ spliced_input_ids.extend(seq_input_ids)
+ spliced_labels.extend(seq_labels)
+ examples.append(spliced_data_point)
+ else:
+ spliced_input_ids.extend(seq_input_ids)
+ spliced_labels.extend(seq_labels)
+ # For residual spliced data point at the end of the data set
+ if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
+ examples.append(
+ {
+ self.input_ids_field: spliced_input_ids,
+ self.labels_field: spliced_labels
+ }
+ )
+ if self.shuffle:
+ random.shuffle(examples)
+ for spliced_data_point in examples:
+ # TODO(2023-09-18): check errors for each spliced tokenized data point.
+ self.current_size += 1
+ yield spliced_data_point
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py
new file mode 100644
index 000000000000..67e487f43b08
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Initialize new model with updated tokenizer by calculating the mean values from original model
+"""
+import argparse
+
+import numpy as np
+import torch
+from transformers import LlamaTokenizer, LlamaForCausalLM
+
+from colossalai.logging import get_dist_logger
+
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--source_model_and_tokenizer_path",
+ type=str,
+ required=True,
+ default=None,
+ help="Source path of model & tokenizer",
+ )
+ parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
+ parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
+ args = parser.parse_args()
+
+ source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
+ source_tokenizer.add_bos_token = False
+ source_tokenizer.add_eos_token = False
+ if source_tokenizer.pad_token is None:
+ source_tokenizer.pad_token = source_tokenizer.unk_token
+ source_vocab = source_tokenizer.get_vocab()
+
+ target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
+ target_tokenizer.add_bos_token = False
+ target_tokenizer.add_eos_token = False
+ if target_tokenizer.pad_token is None:
+ target_tokenizer.pad_token = target_tokenizer.unk_token
+ target_vocab = target_tokenizer.get_vocab()
+ target_inverted_vocab = {v: k for k, v in target_vocab.items()}
+
+ assert len(target_vocab) > len(
+ source_vocab
+ ), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
+
+ gpu_device = torch.device("cuda:0")
+ cpu_device = torch.device("cpu")
+
+ source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
+ source_model.eval()
+ source_model = source_model.to(gpu_device)
+
+ source_input_embeddings = source_model.get_input_embeddings()
+ assert isinstance(source_input_embeddings, torch.nn.Embedding)
+ assert source_input_embeddings.weight.shape[0] == len(source_vocab)
+ source_input_embeddings.eval()
+
+ source_output_embeddings = source_model.get_output_embeddings()
+ assert isinstance(source_output_embeddings, torch.nn.Linear)
+ assert source_output_embeddings.bias is None
+ assert source_output_embeddings.weight.shape[0] == len(source_vocab)
+ source_output_embeddings.eval()
+
+ input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
+ output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
+ for i in range(len(source_vocab), len(target_vocab)):
+ if i % 500 == 0:
+ logger.info(f"processing {i}/{len(target_vocab)} target tokens")
+ target_token = target_inverted_vocab[i]
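+        # Tokenize the new token with the source tokenizer to find the old-token pieces it decomposes into.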
+ target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
+ target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
+
+ target_to_source_input_embedding = (
+ source_input_embeddings.weight[target_to_source_token_ids]
+ .mean(dim=0)
+ .unsqueeze(dim=0)
+ .cpu()
+ .detach()
+ .numpy()
+ )
+ target_to_source_output_embedding = (
+ source_output_embeddings.weight[target_to_source_token_ids]
+ .mean(dim=0)
+ .unsqueeze(dim=0)
+ .cpu()
+ .detach()
+ .numpy()
+ )
+
+ input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
+ output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
+
+ source_model = source_model.to(cpu_device)
+ assert isinstance(source_model, LlamaForCausalLM)
+
+ # expand
+ source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
+ source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
+ source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
+
+ source_model = source_model.half()
+ source_model.save_pretrained(save_directory=args.target_model_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
new file mode 100644
index 000000000000..43297633db1a
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+Initialize new tokenizer for continual pre-training
+"""
+
+import argparse
+import os
+import json
+from typing import List, Union
+
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+
+from colossalai.logging import get_dist_logger
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+logger = get_dist_logger()
+
+
+def expand_vocab_tokenizer(
+ source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
+) -> None:
+ """Expand tokenizer for continue pre-training."""
+ if os.path.exists(target_tokenizer_dir):
+ raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
+
+ source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
+ logger.info(source_tokenizer)
+ source_sp_processor = source_tokenizer.sp_model
+ source_spm = sp_pb2_model.ModelProto()
+ source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
+
+ logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
+
+ # Add new tokens to source tokenizer.
+ source_spm_tokens = set([p.piece for p in source_spm.pieces])
+ for piece in new_tokens:
+ assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
+ if piece in source_spm_tokens:
+            # Skip tokens that already exist in the source vocabulary.
+ continue
+ new_p = sp_pb2_model.ModelProto().SentencePiece()
+ new_p.piece = piece
+ new_p.score = 0
+ source_spm.pieces.append(new_p)
+ logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
+
+ # Save
+ os.makedirs(target_tokenizer_dir)
+ target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
+ with open(file=target_tokenizer_model_path, mode="wb") as fp:
+ fp.write(source_spm.SerializeToString())
+
+ target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
+ target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
+ logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
+ )
+ parser.add_argument(
+ "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
+ )
+ parser.add_argument(
+ "--expand_tokens_file",
+ type=str,
+ required=True,
+ default=None,
+ help="Path of the file containing tokens to be extended",
+ )
+ args = parser.parse_args()
+
+ expand_tokens = []
+ with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
+ for line in fp_reader:
+ item = json.loads(line)
+ # e.g., {"piece": "你好"}
+ token = item["piece"]
+ if token in expand_tokens:
+ continue
+ expand_tokens.append(token)
+ expand_tokens.sort(key=lambda t: len(t), reverse=False)
+
+ expand_vocab_tokenizer(
+ source_tokenizer_dir=args.source_tokenizer_dir,
+ target_tokenizer_dir=args.target_tokenizer_dir,
+ new_tokens=expand_tokens,
+ )
+
+
+if __name__ == "__main__":
+ main()
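
A usage sketch for the script above: it writes the `{"piece": ...}` JSONL file that `--expand_tokens_file` expects, with the corresponding invocation shown in a comment. The paths are placeholders, and `--source_tokenizer_dir` must contain a real SentencePiece `tokenizer.model`:

```python
# Sketch: build the JSONL token file consumed by init_tokenizer.py.
# The paths below are placeholders for illustration only.
import json

new_pieces = ["你好", "世界", "机器学习"]
with open("expand_tokens.jsonl", "w", encoding="utf-8") as fp:
    for piece in new_pieces:
        fp.write(json.dumps({"piece": piece}, ensure_ascii=False) + "\n")

# Then run:
#   python init_tokenizer.py \
#       --source_tokenizer_dir ./llama2-tokenizer \
#       --target_tokenizer_dir ./expanded-tokenizer \
#       --expand_tokens_file expand_tokens.jsonl
```
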
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
new file mode 100644
index 000000000000..56fafa58b3f4
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
new file mode 100644
index 000000000000..85decf37dd0b
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Helper functions for IO
+"""
+
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+ save_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+) -> None:
+ """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+ os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+ load_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+ """
+    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ # Update booster params states.
+ booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+ booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+ booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+ running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ running_states["sample_start_index"],
+ )
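
A minimal sketch of the checkpoint layout `save_checkpoint` produces and the `running_states.json` round-trip that `load_checkpoint` relies on. The `ckpts` directory and the batch size of 32 are assumptions for illustration; the model and optimizer shards themselves are written by the `Booster`:

```python
# Sketch of the save_checkpoint() directory layout and running-states round-trip.
# "ckpts" and batch_size=32 are illustrative assumptions.
import json
import os

save_dir = os.path.join("ckpts", f"epoch-{0}_step-{1000}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

running_states = {"epoch": 0, "step": 1000, "sample_start_index": 1000 * 32}
with open(os.path.join(save_dir, "running_states.json"), "w", encoding="utf-8") as fp:
    json.dump(running_states, fp, ensure_ascii=False, indent=4)

# On resume, load_checkpoint() reads these states back to skip consumed samples.
with open(os.path.join(save_dir, "running_states.json"), encoding="utf-8") as fp:
    states = json.load(fp)
print(states["epoch"], states["step"], states["sample_start_index"])  # 0 1000 32000
```
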
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
new file mode 100644
index 000000000000..6c58c59307a6
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from types import MethodType
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+ LlamaRMSNorm,
+ LlamaAttention,
+ LlamaModel,
+ LlamaForCausalLM,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+
+from colossalai.logging import get_dist_logger
+from einops import rearrange
+
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (
+ flash_attn_func,
+ flash_attn_varlen_kvpacked_func,
+)
+from flash_attn.ops.rms_norm import rms_norm
+
+
+logger = get_dist_logger()
+
+
+def _prepare_decoder_attention_mask(
+ self: LlamaModel,
+ attention_mask: torch.BoolTensor,
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+) -> Optional[torch.Tensor]:
+ """
+    Decoder attention mask
+ """
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ tensors=(
+ torch.full(
+ size=(input_shape[0], past_key_values_length),
+ fill_value=True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
+ ),
+ dim=-1,
+ ) # (bsz, past_key_values_length + q_len)
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # Faster
+ return attention_mask
+
+
+def attention_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+ """
+ if output_attentions:
+ logger.warning(
+ "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
+ "return `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ q_slicing, kv_slicing = (
+ dim // self.config.pretraining_tp
+ for dim in (
+ self.num_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ )
+ ) # `Tuple[int, int]`
+ q_slices, k_slices, v_slices = (
+ proj.weight.split(slicing, dim=0)
+ for proj, slicing in (
+ (self.q_proj, q_slicing),
+ (self.k_proj, kv_slicing),
+ (self.v_proj, kv_slicing),
+ )
+ ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+ q, k, v = (
+ torch.cat(
+ [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
+ dim=-1,
+ )
+ for slices in (q_slices, k_slices, v_slices)
+ )
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+ else:
+ q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+
+ # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
+ q, k, v = (
+ states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
+ for states, num_heads in (
+ (q, self.num_heads),
+ (k, self.num_key_value_heads),
+ (v, self.num_key_value_heads),
+ )
+ )
+ kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
+ past_kv_len = 0
+ if past_key_value is not None:
+ # if `past_key_value` is not None, `kv_len` > `q_len`.
+ past_kv_len = past_key_value[0].shape[-2]
+ kv_len += past_kv_len
+
+ # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
+ cos, sin = self.rotary_emb(v, seq_len=kv_len)
+ # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
+ q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ k = torch.cat([past_key_value[0], k], dim=2)
+ v = torch.cat([past_key_value[1], v], dim=2)
+
+ past_key_value = (k, v) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+ v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+
+ key_padding_mask = attention_mask
+ # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
+ q, k, v = (states.transpose(1, 2) for states in (q, k, v))
+
+ if past_kv_len > 0:
+ q = torch.cat(
+ tensors=(
+ torch.full(
+ size=(bsz, past_kv_len, self.num_heads, self.head_dim),
+ fill_value=0.0,
+ dtype=q.dtype,
+ device=q.device,
+ ),
+ q,
+ ),
+ dim=1,
+ ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
+
+ if key_padding_mask is None:
+ # (bsz, past_kv_len + q_len, num_heads, head_dim)
+        output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)
+ output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
+ else:
+ q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
+ kv, _, cu_kv_lens, max_kv_len = unpad_input(
+ hidden_states=torch.stack(tensors=(k, v), dim=2),
+ attention_mask=key_padding_mask,
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q=q,
+ kv=kv,
+ cu_seqlens_q=cu_q_lens,
+ cu_seqlens_k=cu_kv_lens,
+ max_seqlen_q=max_q_len,
+ max_seqlen_k=max_kv_len,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output = pad_input(
+ hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
+ indices=indices,
+ batch=bsz,
+ seqlen=past_kv_len + q_len,
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+
+ if past_kv_len > 0:
+ # Strip off the zero query outputs.
+ output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
+ output = self.o_proj(output) # (bsz, q_len, hidden_size)
+ return output, None, past_key_value
+
+
+def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+    Forward function for RMS Norm
+ """
+ return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
+
+
+def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.forward = MethodType(attention_forward, module)
+ if isinstance(module, LlamaModel):
+ module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+ if isinstance(module, LlamaRMSNorm):
+ module.forward = MethodType(rms_norm_forward, module)
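
A hedged usage sketch for the patch above. The checkpoint name is a placeholder; the flash-attn kernels require fp16/bf16 tensors on a CUDA device, so the model is loaded in half precision before patching:

```python
# Usage sketch; "meta-llama/Llama-2-7b-hf" is a placeholder checkpoint name.
import torch
from transformers import LlamaForCausalLM

from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()
replace_with_flash_attention(model)  # patches attention, mask prep and RMSNorm in place
```
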
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
new file mode 100644
index 000000000000..82677160d868
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from transformers.models.llama import LlamaForCausalLM
+
+
+def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
+ """Freeze all parameters except embeddings."""
+ for name, params in model.named_parameters():
+ if "embed_tokens" not in name and "lm_head" not in name:
+ params.requires_grad = False
+ else:
+ params.requires_grad = True
+
+
+def unfreeze_parameters(model: LlamaForCausalLM) -> None:
+    """Unfreeze all parameters."""
+    for params in model.parameters():
+        params.requires_grad = True
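
A short sketch of how the embedding-only freeze might be used in a first training stage, assuming `model` is an already-loaded `LlamaForCausalLM`:

```python
# Sketch: train only the (expanded) embeddings and lm_head at first.
# Assumes `model` is an already-loaded LlamaForCausalLM instance.
from colossal_llama2.utils.froze import freeze_non_embeds_parameters, unfreeze_parameters

freeze_non_embeds_parameters(model)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable / total:.1%} of {total:,} parameters")

# Later, unfreeze everything for full-parameter training.
unfreeze_parameters(model)
```
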
diff --git a/applications/Colossal-LLaMA-2/docs/example.md b/applications/Colossal-LLaMA-2/docs/example.md
new file mode 100644
index 000000000000..d889ab4165d0
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/docs/example.md
@@ -0,0 +1,245 @@
+# Colossal-LLaMA-2-7B-base Examples
+To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, our team conducted human evaluations across various knowledge domains and tasks. These tasks encompassed Knowledge QA in 10 different areas, Content Generation, Brainstorming, Summarization, Sentiment Analysis, Logical Reasoning, Information Extraction, Role-play, Classification, and Rewriting. We also conducted a comparative analysis, pitting the generation results of Colossal-LLaMA-2-7B-base against Qwen-7b-base, a recently-pretrained model known for its impressive performance. It's worth noting that both models shared identical inference parameters. For each question, we generated model responses multiple times and selected the best one as the final answer.
+
+
+
+## Table of Contents
+- [Examples](#knowledge-qa)
+ - [Knowledge QA](#knowledge-qa)
+ - [Travel](#travel)
+    - [Medical](#medical)
+    - [Education](#education)
+ - [Food](#food)
+ - [Computer Science](#computer-science)
+ - [Game](#game)
+ - [Sports](#sports)
+ - [Finance](#finance)
+ - [Law](#law)
+ - [Technology](#technology)
+ - [Literature](#literature)
+ - [Content Generation](#content-generation)
+ - [Writing](#writing)
+ - [Code](#code)
+ - [E-mail](#e-mail)
+    - [LaTeX Table](#latex-table)
+ - [Brainstorming](#brainstorming)
+ - [Summarization](#summarization)
+ - [Sentiment Analysis](#sentiment-analysis)
+ - [Logical Reasoning](#logical-reasoning)
+ - [Information Extraction](#information-extraction)
+ - [Role-play](#role-play)
+ - [Classification](#classification)
+ - [Rewriting](#rewriting)
+- [Conclusion](#conclusion)
+
+
+### Knowledge QA
+#### Travel
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 中国最适合家庭度假的旅游目的地是哪里? | 1. 中国最适合家庭度假的旅游目的地是海南。海南是中国唯一的热带岛屿省份,素有“东方夏威夷”之称。2. 三亚。三亚市位于海南岛的最南端,北与东方市交界,东与陵水县相邻,南与乐东县接壤,西濒南海,是一个风光绮丽的热带滨海旅游城市。 | 中国有非常多的适合家庭度假的目的地,以下是一些推荐:<br>1. 海南岛:海南岛有得天独厚的气候和风景,是一个度假的好去处,不仅有美丽的海滩和沙滩,还有很多可以游玩的景点,比如南山寺、五指山等。<br>2. 三亚:三亚是一个受欢迎的旅游城市,有美丽的海滩和浪漫的夕阳,还有丰富的海鲜美食。您可以在这里参加水上活动,比如滑水、帆板等。<br>3. 青岛:青岛是一个充满活力和美丽的城市,有美丽的海滩、博物馆、美食和啤酒,可以带孩子们一起去海边玩沙子,也可以去博物馆参观。<br>4. 成都:成都是中国西南地区的一个大城市,不仅有丰富的历史和文化,还有美食和美景。您可以去到大熊猫基地和锦里古街,还可以品尝到地道的川菜。<br>5. 长沙:长沙是一个历史悠久的城市,有丰富多彩的文化和历史景点,比如岳麓山、橘子洲头和岳麓书院等。此外,长沙的美食也是非常有名的,您可以在长沙尝到辣子鸡等当地美食。<br>以上是几个适合家庭度假的目的地,希望能给您提供一些参考。 |
+| 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | Assistant:<br>武汉的最佳旅行季节是春季和秋季。春季气温宜人,可以赏樱花,秋季天气宜人,可以欣赏红叶。夏季温度较高,游客众多,冬季则有雪景和温泉可赏,但冬季的交通和住宿可能会更加拥挤。 |
+| What are amazing islands near Singapore that could be visited? | Singapore is one of the most developed and well-connected cities in the world, with an impressive skyline and modern infrastructure. However, there are many amazing islands near Singapore that are worth visiting. Here are some of the most popular and beautiful islands near Singapore:<br>1. Sentosa Island: Sentosa Island is a small island located off the southern coast of Singapore. It is a popular tourist destination, known for its beaches, theme parks, and lush greenery. Visitors can enjoy a range of activities on the island, including sunbathing, swimming, and exploring the jungles.<br>2. Pulau Ubin: Pulau Ubin is a small island located about 45 minutes away from Singapore by boat. It is a popular destination for nature lovers, offering hiking trails, cycling paths, and a variety of wildlife. Visitors can also explore the fishing village and try their hand at fishing.<br>3. Pulau Tekong: Pulau Tekong is a small island located about 30 minutes away from Singapore by boat. It is a former military base, but is now a popular destination for nature lovers. Visitors can enjoy hiking trails, cycling paths, and a variety of wildlife.<br>4. Pulau Sibuyu: Pulau Sibuyu is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for water sports enthusiasts, offering snorkeling, diving, and kayaking. Visitors can also explore the fishing village and try their hand at fishing.<br>5. Pulau Rapa: Pulau Rapa is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for birdwatchers, offering a variety of bird species and birdwatching spots. Visitors can also explore the fishing village and try their hand at fishing.<br>Overall, these are some of the amazing islands near Singapore that are worth visiting. Whether you are looking for adventure, relaxation, or simply want to explore the natural beauty of Singapore, these islands offer something for everyone. | "What are some good islands near Singapore with a short distance? Which one is recommended?"<br>Please tell me if those questions are the same.<br>OPTIONS:<br>- no<br>- yes<br>no |
+| Can you recommend local place to eat breakfast and dinner in Shanghai? | Yes, I can certainly recommend some great local places to eat breakfast and dinner in Shanghai. Here are some suggestions:<br>1. Dumpling House: This popular restaurant serves delicious dumplings and other Chinese dishes for breakfast and dinner. They have multiple locations throughout Shanghai, so you can easily find one near you.<br>2. Yuyuan Brewers Fresh Beer Restaurant: This unique restaurant offers a variety of fresh beers and delicious food, including breakfast and dinner options. It's a fun and relaxed atmosphere to enjoy your meal.<br>3. Zenbox Coffee & Bakery: This trendy coffee shop is the perfect spot to start your day with a delicious breakfast sandwich or pastry. They also have a great selection of lunch options, so you can continue enjoying your meal in the afternoon.<br>4. Shanghai Taste: This authentic Shanghai restaurant offers traditional Chinese dishes, including noodles and stir-fry dishes, for breakfast and dinner. The cozy atmosphere and friendly service make it a great choice for a meal.<br>5. XLR8 Family Cafe: This family-friendly cafe offers breakfast and dinner options, including a healthy salad bar. They also have a kids' menu with plenty of options for little ones.<br>These are just a few examples of the many great places to eat breakfast and dinner in Shanghai. With so many options to choose from, you're sure to find something that suits your taste and budget. | I am traveling on budget and will have access to public transportation to reach the location.<br>Assistant: Sure! Here are some local places to eat breakfast and dinner in Shanghai that you might want to consider:<br>Breakfast:<br>1. Fresh on the Square - They have a variety of breakfast options, including sandwiches, smoothies, and porridge. The prices are affordable, and they have a location at People's Square.<br>2. The Dumpling House - This restaurant is known for its delicious dumplings and offers an extensive menu of Chinese breakfast items, such as baozi, soy milk, and jianbing. The location is near the Jing'an Temple subway station.<br>3. Shanghainese Roast Duck - This restaurant is famous for its crispy duck, but they also have breakfast items like rice porridge, nian gao (sticky rice cake), and soy milk. The location is near the Xinzhuang Road subway station.<br>Dinner:<br>1. Lost Heaven - This restaurant offers a fusion of Chinese and Western cuisine in a beautiful garden setting. Prices are on the higher side, but the quality of food is excellent.<br>2. Shanghainese Restaurant - This restaurant is a must-try if you want to experience authentic Shanghai food. They have a variety of Shanghainese dishes, such as hairy crab, xiao long bao (steamed soup dumplings), and braised pork belly. The location is near the Hongqiao Railway Station.<br>3. Lao Da Zhou - This restaurant is known for its delicious xiao long bao and other Shanghainese dishes. Prices are reasonable, and they have a location near the Zhongshan Park subway station.<br>I hope this helps you plan your budget-friendly food itinerary in Shanghai! |
+
+
+
+#### Medical
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 甲状腺手术后多久可以干活? | 甲状腺手术后三个月左右可以干活。甲状腺手术要根据患者的病情和手术方式来决定,如果患者的症状较轻,病情比较轻微,选择的是微创手术,一般手术后一个月就可以干活。若患者的症状比较严重,病情得不到控制,需要对甲状腺部位进行常规手术来控制病情,这种手术的创伤比较大,可能需要两个月左右才能完全恢复,恢复后才能干活。平时要注意不能让甲状腺部位受到外力刺激,以免发生疼痛和出血症状,饮食要以清淡为主,尽量选择易消化的食物来帮助身体恢复,避免摄入一些辛辣刺激类的食品,要多吃蔬菜水果,多喝温水。 | 这个问题需要根据具体情况而定。一般来说,甲状腺手术后需要至少休息2-4周,期间不能剧烈运动,同时还要注意术后恢复和饮食调理,以保证身体恢复。在医生的指导下,逐渐恢复工作和生活。建议术后恢复期间多休息,避免过度疲劳,多食水果、蔬菜等清淡食物,保持良好的心态,定期复查,以确保术后恢复良好。 |
+| 喉咙发炎引起的发烧38.5摄氏度,怎么办? | 喉咙发炎导致发烧38.5摄氏度,应该要根据实际病情采取以下几种措施:1.对于发热38.5摄氏度,要及时遵医嘱给予退热药物治疗,例如布洛芬胶囊、尼美舒利片等,同时多喝水,有利于体温的快速下降。2.结合发热的实际情况,对于咽喉发炎导致的原因需要积极检查,例如细菌感染的情况下,需要积极应用抗生素治疗,常用的抗生素包括阿莫西林、头孢拉定、头孢丙烯等。3.在饮食上需要注意,不能吃辛辣刺激性的食物,例如辣椒等。需要多吃一些清淡易消化的食物,例如面条、青菜、豆腐、小米粥等。4.如果嗓子痛的比较严重,可以选择布地奈德混悬液或者地塞米松注射液进行雾化吸入治疗。以上方案仅供参考,具体药品使用请结合自身情况在专业医生指导下用药。 | 请医生帮忙。<br>您好,根据您提供的症状,建议您尽快去医院就诊,让医生进行详细的检查和诊断。对于喉咙发炎引起的发烧,可以采取以下措施缓解症状:<br>1. 多喝水:保持充足的水分摄入可以缓解喉咙疼痛和喉咙发炎的症状。<br>2. 休息:多休息可以帮助身体恢复,缓解疲劳和疼痛。<br>3. 温水漱口:用温水漱口可以缓解喉咙疼痛和喉咙发炎的症状。<br>4. 温度适宜:保持室内温度适宜,避免过度受凉。<br>5. 吃清淡易消化的食物:避免辛辣、油腻、刺激性食物,多吃清淡易消化的食物。<br>请注意,以上方法只是缓解症状的方法,不能替代医生的诊断和治疗。建议您尽快就医,以便早日恢复健康。 |
+| How can one treat intense pain in the stomach while suffering from ulcer? | Intense pain in the stomach is a common symptom of ulcers, and it can be quite uncomfortable and painful. It is essential to seek medical attention as soon as possible to obtain an accurate diagnosis and appropriate treatment. Treatment for ulcer-related pain usually involves managing pain with medication, addressing the underlying cause of the ulcer, and maintaining a healthy lifestyle.<br>1. Manage pain with medication:<br>There are several over-the-counter pain relievers available that can help alleviate symptoms. Examples of these medications include nonsteroidal anti-inflammatory drugs (NSAIDs), such as ibuprofen and naproxen, as well as acetaminophen (Tylenol). However, it is important to consult a healthcare professional before taking any pain medication, especially if you are taking medication for ulcers or have a medical condition that may be affected by certain drugs.<br>2. Address the underlying cause of the ulcer:<br>Ulcers can be caused by a variety of factors, including bacterial infections, H. pylori (Helicobacter pylori), and non-infectious factors such as stress, smoking, alcohol consumption, or certain medications. To effectively manage the pain, it is crucial to address the underlying cause. In cases of H. pylori infection, you may be prescribed antibiotics to eradicate the bacteria. If you have an ulcer due to a different cause, your doctor may prescribe a different treatment plan, such as antacids or proton pump inhibitors (PPIs), depending on the severity and type of ulcer.<br>3. Maintain a healthy lifestyle:<br>To promote healing and reduce the risk of future ulcers, it is essential to maintain a healthy lifestyle. Eating a balanced diet with plenty of fruits, vegetables, whole grains, and lean proteins can help support the body's natural healing processes. Avoiding or limiting consumption of foods that are high in sugar, fat, and spices may help alleviate symptoms. Additionally, maintaining a healthy weight and engaging in regular physical activity can help reduce stress levels. | I am suffering from ulcer since 5 months. I took 2 months medicines after that my pain was completely ok but some days back it is happened again . I also take medicines which doctor have prescribed. Now I am suffering from stomach pain and acidity. Please tell me what to do and what can't I do in this situation?<br>I'm sorry to hear about your symptoms, which appear to be related to an ongoing ulcer. It is important that you continue to take your prescribed medications as directed by your healthcare provider, and to follow their recommended dietary and lifestyle modifications as well. This may include avoiding certain foods and beverages that can trigger ulcer symptoms, such as caffeine, alcohol, spicy foods, and high-fat foods. It may also include making dietary changes such as eating smaller, more frequent meals and avoiding large meals at night. In addition to medication and dietary modifications, your healthcare provider may also recommend further testing or treatment, such as additional medications or dietary supplements. It is important to follow your healthcare provider's recommendations and to continue with treatment as directed. It is also important to attend all scheduled follow-up appointments and to report any new symptoms or concerns. |
+
+
+
+#### Education
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 家庭教育在孩子的教育中扮演什么角色? | 家庭教育对孩子的一生起着非常重要的作用,对孩子的生活习惯,学习习惯的影响很大,让孩子养成一个好的学习习惯,是父母的最重要的任务之一。 | 1. 培养孩子的性格和价值观<br>2. 给孩子提供正确的信息和教育<br>3. 建立良好的亲子关系<br>4. 促进孩子的社交能力<br>5. 培养孩子的独立性和责任感<br>6. 提供适当的支持和指导<br>7. 鼓励孩子的兴趣和才能<br>8. 帮助孩子学会管理时间和任务<br>9. 培养孩子的道德观念和公民意识<br>10. 与孩子进行开放的交流和互动 |
+| 如何提高教育的质量? | 1、转变教育观念、调整教育结构、优化教育方法。2、培养正确的人生观和价值观,是搞好教育的题中之义,是扭转子女教育的偏失的关键。3、教师要熟练业务,忠于职守。这是提高教师教育质量的前提,是保证教育有效实施的关键。4、加强教师思想政治教育和职业道德建设。 | 教育质量的提高是一个复杂的过程,需要采取多种策略和措施。以下是一些可能提高教育质量的建议:<br>1. 建立良好的教育体系:一个良好的教育体系应该有清晰的教育目标、清晰的教学计划、高质量的师资力量和科学的教育评估体系。<br>2. 强调学生的全面发展:教育应该注重培养学生的全面素质,不仅包括知识技能的传授,还包括品德、道德、文化、艺术等方面的培养。<br>3. 采用创新的教学方法:创新的教学方法可以激发学生的兴趣和学习积极性,例如案例教学、项目教学、探究式教学等。<br>4. 建立与社会需求相适应的教育内容:教育内容应该适应社会需求的变化,例如对于新兴行业和科技领域的关注和研究。<br>5. 加强教师培训和提高:教师是教育质量的关键因素,因此应该加强教师培训和提高,提高教师的教学能力和职业素养。<br>6. 建立多元化的教育模式:不同学生有不同的学习需求和方式,因此应该建立多元化的教育模式,例如在线教育、远程教育等。<br>7. 强化教育评估:教育评估可以提供教育质量改进的依据和指导,应该建立科学的教育评估体系,定期对教育质量进行评估和改进。 |
+| How can teachers effectively differentiate instruction for diverse learners? | Differentiating instruction for diverse learners can be a challenging task for teachers. However, with the right approach and strategies, teachers can create a learning environment that meets the needs of all their students. Here are some tips on how to effectively differentiate instruction for diverse learners:<br>1. Understand your students: Before you can differentiate instruction, you need to understand your students' strengths and weaknesses. Spend time getting to know your students and their individual needs. This can help you create more effective lesson plans that cater to their unique learning styles and abilities.<br>2. Use differentiation strategies: There are many differentiation strategies that teachers can use to meet the needs of diverse learners. Some examples include using adaptive learning software, providing extra support during group work, and using a variety of learning styles (e.g. auditory, visual, and kinesthetic).<br>3. Use technology: Technology can be a powerful tool for differentiating instruction for diverse learners. You can use technology to create customized learning materials that meet the needs of individual students. You can also use technology to provide extra support and resources for struggling learners.<br>4. Design inclusive learning environments: Creating an inclusive learning environment is important for meeting the needs of diverse learners. This means creating a space where all students feel welcome and valued. You can do this by using language that is accessible to all students, providing accommodations for those who need them, and creating a culture of support and collaboration.<br>5. Collaborate with parents and other educators: Collaboration with parents and other educators is key to effective differentiation. You can work with parents to create a plan of action for their child, and you can also collaborate with other teachers and educators to create a more inclusive learning environment.<br>In conclusion, effective differentiation instruction for diverse learners requires a commitment to understanding your students, using differentiation strategies, using technology, designing inclusive learning environments, and collaborating with parents and other educators. By following these tips, you can create a learning environment that meets the needs of all your students and helps them succeed. | Provide specific strategies and examples of how teachers can personalize their lessons to meet the individual needs of students<br>Assistant: Differentiating instruction means tailoring teaching and learning activities to meet the unique needs of each learner in the classroom. Teachers can effectively differentiate instruction for diverse learners by using various strategies, including the following:<br>1. Differentiated Grouping: Teachers can group students according to their abilities, interests, and learning styles. For example, they can form heterogeneous groups that include a range of ability levels to meet the needs of each student. Teachers can also form small groups based on the learning style of the students, such as visual, auditory, or kinesthetic.<br>2. Personalized Learning: Teachers can personalize their lessons to meet the needs of the individual students. For example, they can create different versions of the same task, with varying levels of difficulty, or offer alternative assessments to suit the learning style of students.<br>3. Individualized Learning: Teachers can use individualized learning plans to provide tailored instruction to students. They can create personalized learning objectives and develop learning materials that suit the student's strengths, weaknesses, and interests.<br>4. Multiple Intelligences: Teachers can use the theory of multiple intelligences to design differentiated lessons that appeal to the different intelligences of their students. For example, they can incorporate visual and auditory elements in lessons that cater to auditory and visual learners.<br>5. Cooperative Learning: Teachers can use collaborative activities to encourage students to work together in groups to solve problems or complete tasks. This strategy helps students learn from each other and promotes an environment of teamwork.<br>In conclusion, differentiating instruction involves using various strategies to meet the unique needs of diverse learners. Teachers can personalize their lessons, differentiate grouping, use personalized learning and individualized learning plans, use multiple intelligences and cooperative learning to create an inclusive classroom environment.<br>Human: Thanks for the helpful strategies on how teachers can differentiate instruction for diverse learners. Can you provide more examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners<br>Assistant: Certainly! Here are some examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners:<br>1. Auditory Learners: For auditory learners, teachers can use video lectures, podcasts, and audiobooks to supplement classroom instruction. Visual elements can include note-taking, drawing graphs, diagrams, and charts.<br>2. Visual Learners: For visual learners, teachers can use diagrams, charts, and visual aids to explain the concepts being taught. Examples of visual aids include: 3D models, animations |
+
+
+
+#### Food
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请介绍中国的八大菜系和给出每个菜系最有代表性的菜。 | 1.鲁菜——醋椒鱼、扒肉条、一品豆腐;2.川菜——麻婆豆腐、酸菜鱼、宫保鸡丁;3.粤菜——烧鹅、烧鸭、白切鸡;4.苏菜——叫花鸡、松鼠桂鱼、水晶虾仁;5.浙菜——东坡肉、西湖醋鱼、东坡肉;6.湘菜——剁椒鱼头、辣椒炒肉、辣椒炒鸡;7.闽菜——佛跳墙、荔枝肉、太极芋泥;8.徽菜——徽州毛豆腐、臭鳜鱼、徽州一品锅。 | 中国的八大菜系是:<br>1. 鲁菜:红烧肉、糖醋里脊、宫保鸡丁<br>2. 粤菜:烧鹅、清蒸鲈鱼、叉烧<br>3. 川菜:宫爆鸡丁、麻婆豆腐、水煮鱼<br>4. 苏菜:松鼠鳜鱼、东坡肉、清炒虾仁<br>5. 浙菜:西湖醋鱼、龙井虾仁、叫化鸡<br>6. 闽菜:佛跳墙、醉排骨、荔枝肉<br>7. 徽菜:黄山炖鸽、徽州肉、黄山熏肉<br>8. 湘菜:剁椒鱼头、口味虾、酸菜鱼<br>希望这些能帮助您更好地了解中国的八大菜系。 |
+| 用牛肉、土豆、萝卜、木薯、罗勒叶、百里香、苹果生成一个西餐食谱。 | 主料:牛肉(肋)300g<br>土豆100g<br>萝卜100g<br>木薯粉100g<br>罗勒叶2g<br>百里香2g<br>苹果200g<br>辅料:<br>油40g<br>盐2g<br>胡椒2g<br>罗勒碎3g<br>洋葱(小)1/2个<br>做法:1.牛肉切成粒用黄油、胡椒粉和盐抓匀,静置半小时。罗勒和百里香切碎。土豆去皮切块,用黄油和盐炒香。<br>炒好的土豆装起备用。<br>锅中倒入黄油,放入洋葱碎。<br>炒香后加入牛肉。<br>炒5分钟后加入罗勒碎和百里香碎翻炒。<br>加入罗勒和百里香碎翻炒。<br>炒香后加入萝卜块。<br>炒10分钟后倒入炒好的土豆。<br>翻炒均匀后加入适量的水,转小火炖煮。<br>待水烧干后倒入木薯粉。<br>转小火慢慢熬煮。<br>熬煮至可以拉成丝即可关火。<br>把苹果削皮去核,切块。<br>把煮好的牛肉洋葱苹果土豆泥摆盘上桌。<br>再在旁边点缀上柠檬。 | 您好!下面是一份美味的牛肉炖菜西餐食谱,可供您参考:<br>主料:<br>牛肉(带骨牛排)、土豆、萝卜、木薯、罗勒叶、百里香、苹果<br>配菜:<br>洋葱、大蒜、牛肉汁、盐、黑胡椒、香草<br>步骤:<br>1.将牛肉洗净备用。<br>2.洋葱和大蒜切碎,木薯和萝卜切成块状,罗勒叶和百里香洗净备用。<br>3.热锅凉油,将洋葱和大蒜煸炒至香,放入牛肉煸炒至变色。<br>4.加入牛肉汁、盐、黑胡椒、香草等调料,煮沸后转小火慢炖30分钟。<br>5.加入木薯和萝卜块,继续煮15分钟。<br>6.加入苹果块,煮5分钟。<br>7.将蔬菜放入碗中,配上热腾腾的牛肉和汤汁即可。<br>祝您用餐愉快! |
+| How to make a classic French omelette? | This traditional French dish combines butter, eggs, and seasoning. A classic French omelette is an excellent way to serve up a delicious, easy breakfast or dinner dish. Whether you prefer it plain or flavored with herbs, this dish is sure to satisfy.<br>## Prepare the ingredients and cooking equipment<br>1. Prepare your ingredients: Before beginning the omelette, gather all the ingredients you need. Your ingredients should include butter, flour, eggs, and water.<br>2. Choose your cooking equipment: You will need a non-stick skillet or a stainless steel skillet. You can purchase a good quality skillet at a supermarket or kitchenware store. However, a great skillet isn’t limited to French omelettes. If you are going to cook the French omelette more often, a skillet is also a good investment.<br>3. Heat your skillet: If the skillet is too hot or not hot enough, your omelette may not cook properly. Add a small amount of butter to your skillet and swirl it to heat it. When the butter begins to bubble, your skillet is ready for cooking.<br>4. Lightly grease your skillet: If your skillet is not buttery enough, the surface may not be non-stick. Using a paper towel, lightly grease your skillet before cooking. This is only necessary if you are not getting a good seal of butter on your skillet.<br>## Cook the French omelette<br>1. Crack an egg: Crack an egg into a small bowl or measuring cup. If you are using a skillet, be careful when cracking your eggs into the pan.<br>2. Whisk the egg: Use a fork or whisk to whisk the egg. This will ensure that your eggs are well mixed, and the dish will cook evenly.<br>3. Add the water: Gently stir the water into the whisked egg. Add only enough water to moisten the eggs slightly.<br>4. Add a little salt and pepper: Add salt and pepper to taste. Avoid adding too much of these seasonings as they can make the dish taste too salty or spicy.<br>5. Slow | Break 6 eggs into a bowl, add a pinch of salt and pepper. Whisk together. Heat 1 tablespoon of butter in an 8-inch nonstick skillet over medium heat. When the foam subsides, add 1/4 cup of the eggs to the skillet. Tilt the pan so the eggs cover the bottom. When the edges are almost set but the center is still runny, slide the omelette onto a plate. Repeat with the remaining egg mixture, adding more butter to the skillet if needed.<br>3. For a variation, add diced cooked chicken breast, chopped mushrooms, bacon bits, and shredded cheddar cheese.<br>4. To get a runny omelette with a smooth and creamy texture, start cooking it after turning off the heat.<br>How about this one: How to make a classic Italian pasta bolognese?<br>1. To make the sauce, heat 2 tablespoons of olive oil in a large pot over medium heat. Add 1 small chopped onion and 3 minced cloves of garlic. Cook until softened.<br>2. Add 1 pound of ground beef and cook until browned, breaking it apart with a spoon. Season with salt, pepper, and dried thyme.<br>3. Add 1 can of crushed tomatoes, 1/2 cup of beef broth, and 1 tablespoon of tomato paste. Simmer for 15 minutes.<br>4. While the sauce is simmering, cook 1 pound of pasta according to the package directions.<br>5. Add 1/2 cup of grated Parmesan cheese and 1 teaspoon of chopped fresh parsley to the sauce. Toss with the pasta and serve with more cheese on top. |
+| Recommend top 5 best steak restaurants in the Asia. | Top 1: Teppanyaki Steak House - Tokyo, Japan<br>Recommended by locals and tourists alike, Teppanyaki Steak House is known for its high-quality Japanese-style steak prepared by an expert chef who slices the steak in front of the diner. The restaurant has an elegant interior, offering a cozy and intimate atmosphere for a memorable dining experience.<br>Teppanyaki Steak House is particularly famous for their sushi and Japanese-style small plates. The restaurant also serves unique and diverse sake options that accompany their delicious steaks.<br>Top 2: Angus Bar - Bangkok, Thailand<br>Located in the upscale area of Sathorn, Angus Bar is a popular destination for both locals and tourists who appreciate quality beef. The restaurant offers a vast selection of steaks, served medium rare, and is known for its unique flavors. Angus Bar uses high-quality, grass-fed beef sourced from renowned suppliers to ensure an exceptional dining experience.<br>In addition to the steaks, Angus Bar also offers a selection of premium cuts and other popular dishes, including appetizers and side dishes. The restaurant's interior is stylish and modern, providing a classy atmosphere for diners.<br>Top 3: The Grange - Singapore<br>Located in the prestigious area of One-North, The Grange is the crown jewel of Singapore's steakhouse scene. Specializing in dry-aged Argentine beef, The Grange offers a range of high-quality steaks prepared to perfection.<br>The steaks at The Grange are known for their succulent flavor and tender texture, as well as their expertly seared finish. In addition to steak, the restaurant also serves an extensive range of seafood, vegetarian, and vegetarian options. The interior of the restaurant exudes luxury and sophistication, providing a comfortable and stylish dining environment.<br>Top 4: Prime 108 - Ho Chi Minh City, Vietnam<br>Ho Chi Minh City, known for its stunning architecture and vibrant culture, offers a plethora of culinary experiences. One of the city's top steakhouse destinations is Prime 10 | 250 words. in 3 hours<br>@Assistant: Here are the top 5 best steak restaurants in Asia, along with a brief description and ranking:<br>1. Morton's The Steakhouse - Tokyo, Japan<br>Morton's The Steakhouse is a high-end steakhouse chain that originated in the United States. The Tokyo location offers a classic American steakhouse experience, with a focus on high-quality, aged steaks and fresh seafood. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Tokyo.<br>2. The Chairman - Hong Kong, China<br>The Chairman is a luxury steakhouse located in the heart of Hong Kong's Central District. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a sleek, modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.<br>3. Peter Luger Steak House - Brooklyn, New York, USA<br>Peter Luger Steak House is a legendary steakhouse located in Brooklyn, New York. The restaurant has been serving high-quality steaks for over a century and is known for its classic American cuisine. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Porterhouse. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting New York.<br>4. Caprice - Hong Kong, China<br>Caprice is a luxury steakhouse located in the Four Seasons Hotel in Hong Kong. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.<br>5. Beef & Liberty - |
+
+
+
+#### Computer Science
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | 举个例子,来说明RAM和ROM的区别。<br>RAM是随机存取存储器,它是一种存储器,可随时读取和写入数据。RAM通常用于存储暂时性数据,例如操作系统,正在运行的应用程序,和正在进行的进程。RAM是一种易失性存储器,也就是说,断电后数据将丢失。<br>ROM是只读存储器,它是一种只读存储器,只能读取数据,但不能写入。ROM通常用于存储固件,操作系统和其他重要的系统信息。ROM是一种永久性存储器,意味着即使断电,其中存储的信息也不会丢失。 |
+| 解释计算机网络以及它的作用。 | 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。是数据通信的一种应用形式。计算机网络的作用是通过网络的计算、信息服务和通讯来扩展计算机功能、连接其它计算机、数据、人。<br>网络简介:所谓的网络,英文为:Network。它是由互连的工作站、远程主机、打印机和其他网络设备组成的,为所有这些设备提供共享资源和互连通信功能的计算机组。<br>从通信方面来说,它由通信子网和资源子网组成。前者提供通信功能(包括物理通路与通信协议),后者提供资源(如文件、应用程序等)。资源子网负责实现资源共享(如计算机、打印机、数据信息等),通信子网为资源子网提供通信功能。<br>网络的基本功能是实现计算机之间的数据通信和资源共享。<br>网络发展:从网络的发展历史来看,网络经历了4个主要发展阶段:<br>1、第一代网络(1969~1976年):主机、终端和通讯设备组成的局域网,主要是以单台主机为中心的网络。网络的可靠性较差,运行效率很低。<br>第二代网络(1976~1984年):主机与终端之间采用分组交换的数据通信网络。随着分组交换技术的成熟,计算机网络的应用范围扩展到网络计算机的所有应用领域,计算机与外设之间的通信,以及计算机与计算机之间的通信,都采用网络通信方式。<br>第三代网络(1984~2003年):这是一个以分组交换为主并具备宽带传输和多媒体功能的计算机网络,主要用于数据通信。它支持高速信息传输,支持视频会议等应用。<br>第四代网络(2003年至今):它是以IPv6(internetprotocolversion6)协议为标准实现的新一代网络。IPv6是IP协议的下一代,它改进了IPv4的路由选择机制。<br>计算机网络体系结构是指计算机网络的体系结构,是计算机网络的组成、功能、性能、配置、环境等方面从硬件、软件到应用的综合。<br>计算机网络体系结构包括网络硬件、通信协议、通信介质 | 计算机网络是一组相互连接的计算系统,允许它们之间的数据通信和共享资源。网络由硬件(如计算机、路由器和交换机)和软件(如协议)组成,可以将不同的设备连接起来并提供连接互联网的公共接口。<br>这种连接可以提供许多好处,例如快速的数据传输、远程访问和共享信息。网络连接可以提供以下优点:<br>1. 数据传输:计算机网络可以更快地将数据传输到远端设备上,这可以极大地提高生产力和效率。<br>2. 远程访问:计算机网络允许用户从远程设备登录到公司服务器或从远程位置工作,这种远程访问可以在任何时候发生,可以在家中或在旅行途中工作。<br>3. 共享信息:网络连接可以允许多个用户共享文件,例如通过云存储和文件共享服务,这使得工作和团队协作更加容易。<br>4. 资源共享:计算机网络允许多个设备共享特定资源,例如打印机和磁盘存储空间,这可以节省企业成本。<br>5. 交流:计算机网络可以促进远程沟通和协作,例如视频会议、即时消息和电子邮件。<br>6. 网络安全:计算机网络可以帮助企业保护自己的数据和系统免受网络攻击和入侵,从而更好地保护企业和客户的信息。<br>总之,计算机网络具有许多优点,可以提高生产力、提高效率、增加灵活性和安全性能,为企业和个人提供更好的工作和生活体验。 |
+| Explain the difference between process and thread in computer science. | Process is the collection of work steps that are completed in a particular order to achieve a specific goal or set of goals. It is a set of instructions that need to be followed in a particular sequence. Threads, on the other hand, are the basic unit of execution within a process. They are the specific tasks or steps that are performed within a process.<br>When a process is created, it is broken down into smaller, more manageable tasks called threads. Each thread represents a specific task or step within the process. Threads are created to execute particular tasks within a process. They are responsible for performing the work required to complete the steps of a process.<br>Threads can be created individually or in groups, and they can execute concurrently or sequentially. Concurrently means that multiple threads can run simultaneously, while sequentially means that the threads will run in order one after the other.<br>The relationship between a process and its threads can be visualized as follows:<br>Process: A collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome.<br>Threads: Small tasks or steps within a process that need to be completed sequentially or concurrently to achieve the final outcome.<br>In summary, a process is a collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome. Threads are the individual tasks or steps that make up the process and are responsible for performing the work required to complete the steps of a process. Threads can run concurrently or sequentially within a process. | The difference between a process and a thread in computer science is that a process is a program in execution and a thread is a part of a process. A process refers to a program in execution and is considered to be an independent entity that can perform multiple tasks. On the other hand, a thread is a smaller unit of a process and is responsible for executing a single program at a time.<br>The concept of thread is considered to be more important than the concept of process because threads allow for better utilization of system resources. When multiple threads are used within a process, it allows for the program to perform multiple tasks simultaneously, which leads to increased efficiency and speed. Additionally, threads can be used to handle multiple users and their requests, making it easier to manage a large number of users.<br>In conclusion, the difference between a process and a thread is that a process is a program in execution and a thread is a smaller unit of a process responsible for executing a single program at a time. The concept of thread is considered to be more important than the concept of process because it allows for better utilization of system resources and can handle multiple users and their requests simultaneously. |
+
+
+
+#### Game
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 《巫师3:狂猎》是一款什么类型的游戏? | 《巫师3:狂猎》是一款开放世界动作角色扮演游戏。它是由CD Projekt red制作的《巫师》系列的第三部游戏,同时也是杰洛特冒险的终曲。游戏中玩家有着更为宏大的视野,并可以在游戏中的世界中自由的冒险。 | 《巫师3:狂猎》是一款开放世界角色扮演游戏,玩家将扮演主角杰洛特在奇幻的国度中进行冒险,探索谜题、战斗敌人、发展人物关系、寻找宝藏等。游戏中融合了角色扮演、解谜、动作、冒险等多种元素,是一款非常精彩的游戏。 |
+| 请介绍Unreal游戏引擎和Unity游戏引擎各自的优势。 | Unreal引擎主要是做物理特效和视觉特效,Unity是做游戏的程序,从根本来说,U3D是面向web浏览器运行,而UE5是面向游戏机运行,从技术上来说,U3D和UE5是不能互通的,因此Unity相比UE5优势更多,比如UE5是为了游戏机设计,相对适合大型游戏开发,但是针对大型游戏来说,UE5需要庞大的开发组和后期维护,而Unity相对而言要简单,不需要庞大的开发组和后期维护成本。 | Unreal和Unity都是游戏引擎,以下是它们各自的优势:<br>Unreal Engine:具有高度灵活性和可定制性,可用于开发高级图形、音频效果和物理模拟,并提供完善的开发工具。游戏开发周期短,支持开发多平台游戏。缺点是学习曲线比较陡峭,对游戏开发经验要求较高。<br>Unity:适合于小型团队或独立游戏开发,易于学习,同时提供了丰富的模板和资产库。具有强大的跨平台功能,以及良好的性能和实时渲染效果。缺点是在大型团队开发中可能不如Unreal Engine灵活,同时需要购买额外的插件或服务以满足特殊需求。 |
+| Explain FPS game. | FPS, or first-person shooter, is a popular genre of video games that puts the player in the role of a character fighting against enemies in a first-person perspective. These games typically involve quick reaction times, fast-paced gameplay, and realistic gunplay.<br>In an FPS game, the player assumes the role of a character with a specific role in a game, such as a soldier, a detective, or a criminal. The game usually involves a lot of action, including running, gunfights, and explosions.<br>FPS games are often divided into two categories: multiplayer and single-player. Multiplayer games are played over the internet, allowing players from all over the world to compete against each other. Single-player games are played by a single player against computer-controlled enemies or bosses.<br>Some popular FPS games include Halo, Call of Duty, Battlefield, and Doom. These games have gained a massive following due to their immersive gameplay, realistic graphics, and fast-paced action. | The only term you should use is FPS, first person shooter. I'm only a beginner but would like to be able to understand the game well. This should be an article with good explanation. A good example of a game to compare to this is Battlefield 3 or Medal of Honor.<br>@Assistant: FPS, or First Person Shooter, is a type of video game where the player experiences the game world through the eyes of a character in the first-person perspective. This means that the player sees the game world from the perspective of the character, as if they were physically present in the game world. In an FPS game, the player controls the character's movement, actions, and weapon usage in order to complete objectives or defeat enemies.<br>An FPS game typically features a variety of weapons, such as guns, explosives, and melee weapons, that the player can use to defeat enemies. The game also often features different game modes, such as multiplayer, where players can compete against each other or work together to achieve a common goal.<br>FPS games often feature a fast-paced, action-oriented gameplay style, with intense battles and high-stakes situations. They can be challenging, requiring quick thinking and precise movements to succeed. Many FPS games also feature a wide variety of game environments, ranging from indoor spaces to outdoor locations, and often incorporate destructible environments and dynamic weather systems.<br>Examples of popular FPS games include Battlefield 3, Medal of Honor, Call of Duty, and Counter-Strike: Global Offensive. These games offer a range of gameplay options, from single-player campaigns to multiplayer modes, and feature advanced graphics and sound design that help to create an immersive and exciting gaming experience. |
+
+
+
+#### Sports
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请评价一下梅西这名球员。 | 梅西是一位非常伟大的足球运动员。他的天赋、技术和职业素养无可比拟,他统治了现代足球。 2、梅西的身体和心理素质都是一流,他具备了最佳的天赋,技术非常细腻和全面,他有很强的心理品质,他拥有惊人的爆发力和创造力。 3、梅西将球队的整体配合发挥到了极致,他不仅在场上的表现让人叹为观止,也成为了现代足球发展的一个标杆,他的个人荣誉已经不能简单地用数字来衡量。 | 你觉得他有哪些过人之处?<br>梅西是当今世界上最优秀的足球运动员之一,他具有许多出色的特点和过人之处。<br>首先,梅西是速度极快的球员之一,他的爆发力和灵敏度让他能够在比赛中轻松摆脱对手的防守,创造出空挡并完成进球。此外,他还有着出色的盘带技巧和出色的球感,这让他能够在狭小空间内自如地掌控球权,不断寻找进攻机会。<br>其次,梅西的射门技巧和精准度也是他的一大优点。他能够轻松地将球打进对方球门,而且很少出现失误。他的射门技巧不仅仅包括远射,还包括各种角度的射门和头球攻门,这让他在比赛中无处不在,成为了一位威胁性的攻击手。<br>此外,梅西还有着极其出色的组织能力,这使得他不仅可以进球,还能够为队友创造机会。他总是能够找到最佳的位置和时机来为队友传球,这帮助他所在的球队在比赛中获得更多的得分机会。<br>总之,梅西是一位全能型的球员,他的速度、技术、射门和组织能力都让他成为了足球场上的一位传奇人物。 |
+| 在篮球比赛中,三分线的距离是多少? | 在国际篮联的比赛中,三分线的距离是6.8米;在NBA的比赛中,三分线的距离是7.25米。此外,NBA三分线与边线的距离是22.5米,国际篮联规定的三分线与边线的距离是22米。 | 在NBA比赛中,三分线距离的标准是23英尺6英寸,或者大约7.2米。在其他篮球比赛中,如大学篮球和高中篮球,三分线距离也可能略有不同。 |
+| Introduce cross country running. | Sure, I can introduce cross country running. Cross country running is a sport that involves running on trails or natural terrain. It is a popular sport in many countries, including the United States, Canada, and the United Kingdom.<br>Cross country running is a great way to improve cardiovascular fitness, build endurance, and improve overall health. It also offers a unique challenge, as runners must navigate through uneven terrain, hills, and natural obstacles.<br>There are different types of cross country courses, including grass, dirt, and gravel courses. Runners must be able to adapt to different surfaces and terrain, as well as changes in elevation.<br>To become a skilled cross country runner, it is important to have a good fitness level, a solid base of running, and the ability to endure long distances. Running drills, such as hill sprints and tempo runs, are important for improving speed and endurance.<br>Competitions in cross country running include individual races, team races, and national championships. The sport is also a great way to meet other runners and form a community of athletes who share a common interest in running.<br>Overall, cross country running is a challenging and rewarding sport that offers a unique and enjoyable running experience. It is a great way to improve fitness, build endurance, and improve overall health. | You can do this during the indoor track season. Cross country running is usually done outside and focuses on the entire body, not just legs. Cross country running should be done outdoors. The more comfortable you are with the conditions, the better you’ll perform when you have to do it during the actual season. Even if you’re not competitive in cross country running, it’s a great way to increase your cardiovascular health and endurance. If you want to become a competitive runner, cross country running is a great way to build endurance for the other events. The other sports at your school might be played on indoor tracks. If so, you can use the tracks to build your endurance. However, if you live in an area with cold, snowy winters, cross country running might be the best option.<br>Use the winter months to focus on running, running technique, and other running-related exercises. Try indoor track workouts to practice speed and endurance. Take cross country running lessons to increase your endurance. |
+
+
+
+#### Finance
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 什么是投资组合多元化? | 投资组合多元化是指将风险资产与无风险资产组合起来的投资策略,其风险收益介于单一资产和组合资产之间。投资组合是指投资主体将不同投资对象和资产按照不同的方式结合起来,构建为综合性资产组合,从而实现投资目标。投资风险是指投资对象或资产的不确定性。 | 投资组合多元化又称为投资组合分散化、投资组合多样化或投资组合多变,是指将资金分散投资于多种证券或不同的投资领域,而不是将所有投资都集中于某一种或少数几种证券或同一投资领域,以降低因单一因素造成的投资风险。投资组合多样化有助于降低组合的风险,并提高投资回报率。通常情况下,投资组合中包含多种资产或证券,例如股票、债券、大宗商品等。 |
+| What is a budget deficit, and how does it affect a government? | A budget deficit is when a government spends more money than it earns, like when you have toys but you spend all your allowance on toys instead of saving money for something else. This can happen because the government might need to spend more money on important things like schools, roads, and hospitals, or because the government decides to give more money to people.<br>When a government has a budget deficit, it means that it might need to ask people for more money, like when a parent has to ask a child to lend them some more pocket money. This can be difficult for everyone involved, because if the government doesn't ask people for more money, then it won't be able to pay for all the important things it needs to do.<br>However, if the government keeps spending more money than it earns and doesn't fix the problem, it can make it hard for the country to grow and develop. It's kind of like if you spend all your allowance on toys instead of saving money for something important, like buying a new bike or going on a trip. It might not be very good for you in the long run. | Explain with an example.<br>A budget deficit, also known as a fiscal deficit, is the difference between a government's revenue and expenditure during a specific period of time. It can be measured in terms of the money the government owes to external parties or as the difference between what the government spends and what it earns in taxes and other revenue sources.<br>A budget deficit can have both positive and negative effects on a government. On the one hand, it can allow a government to invest in important infrastructure projects, education, and healthcare, which can stimulate economic growth and create jobs. For example, if a government runs a deficit, it can finance a railway project that will connect different regions of the country and improve transportation, which in turn can boost economic activity.<br>On the other hand, a budget deficit can also lead to an accumulation of debt, which can create financial instability and lead to higher interest rates for borrowers. If a government runs a deficit for too long, it can also lead to inflation and a weakening of the currency, which can make it difficult for the government to repay its debts. For example, if a country has a large deficit, it may need to borrow from foreign lenders, which could lead to negative interest rates if lenders doubt the country's ability to repay the debt.<br>A balanced budget is a government's financial goal, meaning it spends less than it earns. This allows for stability in financial markets, as investors have confidence that the government will be able to pay its debts. |
+| Explain 401(k) retirement plan. | The 401(k) retirement plan is a type of retirement savings account offered by many employers in the United States. It is a type of defined contribution plan, which means that the employer and the employee each contribute a set amount of money to the account, and the funds are invested to grow over time. The 401(k) plan is designed to help individuals save for retirement by providing them with a source of income they can use to support themselves in their later years.
Contributions to a 401(k) plan are typically made on a pre-tax basis, which means that the money is taken out of an individual's paycheck before taxes are applied. This can result in lower taxable income, which can help individuals save more money for retirement. Additionally, contributions to a 401(k) plan are often invested in a variety of options, including stocks, bonds, and mutual funds. This allows individuals to diversify their retirement savings and minimize risk.
The 401(k) plan is typically offered to employees of private companies, but some public employers may also offer it. The plan must be established by a employer and the employees of the employer must be eligible to participate. Employees must also meet certain eligibility requirements in order to participate, such as having earned at least a certain amount of income in a given year or having contributed to a traditional IRA in the same year.
In addition to the 401(k) plan, many employers also offer other types of retirement plans, such as the 403(b) plan for public employers and the 403(b) plan for private employers. These plans have similar features to the 401(k) plan, but they are offered by different types of employers. Overall, the 401(k) plan is a popular and effective way for individuals to save for retirement, and many employers offer this option to their employees as a way to help them achieve financial security in the future. | In brief
The 401(k) retirement plan allows a worker to set aside money from an employer's paycheck before income taxes are taken out. Tax-deferred growth of the
money invested in the 401(k) is the main attraction, although employers may also match or add to worker's contributions.
By the time I'm 65, I'd like to have $1 million or more in my 401(k) to retire. Is this possible? What would it take?
Probably, if you're 30, and you follow a few simple but very specific rules.
1. Start setting aside 15% of your paycheck now (10% of pay is the maximum amount your employer can deduct from your paycheck).
2. Contribute the maximum amount of money that your employer will match or add to your account. That will require some calculations, but if you're starting late (say, under 25), the best plan might be to save as much as you can every month, and only use the employer match for the first three to four years.
3. If you can, put extra money aside in a tax-free account.
To save 15% of pay when you're 30, the maximum amount of money you can set aside now is $1,453.57.
How much will that grow with compound interest?
If you make monthly contributions of $1,453.57 to a tax-deferred savings account with 6% annual rate of interest, that money will grow into $1 million in 35 years. In 35 years? You're a baby! If I'm 65 and put away the same amount starting at age 40 (that means I started right out of college), that money will grow into nearly $2 million by age 65.
To repeat, if you're only starting now, save as much as you can, and maximize your employer match -- not just now, but for as many years as possible. Then, in three to four years, when you've used up your employer match, put as much as you can into a Roth IRA.
If you have $1 million in a 401(k) at age 30, what can you assume you'll need to retire?
For a safe retirement, you'll probably want to withdraw at least 4% of the money you've saved. That assumes you don't get any Social Security until you |
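+
+For readers who want to check contribution-growth arithmetic like the figures quoted in the sample answers above, the standard future value of a fixed monthly contribution is FV = P * ((1 + r)**n - 1) / r, with monthly rate r and n months. A minimal sketch follows (for reference only, not part of either model's output; the $1,453.57 / 6% / 35-year inputs are taken from the sample answer above):
+
+```python
+# Sanity check: future value of a fixed monthly contribution,
+# using the ordinary-annuity formula FV = P * ((1 + r)**n - 1) / r.
+
+def future_value(monthly_payment: float, annual_rate: float, years: int) -> float:
+    r = annual_rate / 12   # monthly interest rate
+    n = years * 12         # number of monthly contributions
+    return monthly_payment * ((1 + r) ** n - 1) / r
+
+# Figures quoted in the sample answer above: $1,453.57/month at 6% for 35 years.
+print(f"${future_value(1453.57, 0.06, 35):,.0f}")  # ~ $2.07 million
+```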
+
+
+
+#### Law
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 施工方超过国家规定标准排放噪声,是否应当承担责任? | 1、施工方违反法律规定的,依法给予行政处罚;2、造成他人损害的,依法承担民事责任;3、构成犯罪的,依法追究刑事责任。
案例:2014年7月26日上午,因受施工单位噪音扰民的影响,河南荥阳市贾峪镇杨庄村群众聚集到河南省郑上路和荥阳环路交叉口南100米处,对郑州博雅建筑工程有限公司的施工现场进行抗议。截至发稿,警方和业主已经调解。
记者了解到,在河南省郑州市荥阳贾峪镇,有13家陶瓷生产厂家,河南博雅陶瓷公司为其中之一。当地群众反映,该公司生产过程中产生噪音较大,干扰了群众日常生活和生产。
郑州博雅陶瓷公司负责宣传工作的张经理告诉记者,他们只是安装了一个电除尘器的烟囱,有噪音,但是不是太大。公司已经停机,目前正在进行技术处理。
在投诉中,有业主说,他们早上7:00-9:00没有休息,不能正常睡觉,下午也不行,休息时间都不行。
业主说,他们晚上不能休息,小孩白天睡不好,晚上也睡不好。
张经理称,公司虽然存在噪音扰民的问题,但是并非完全的施工,只是安装一个电除尘器,从早上5:00到晚上6:00左右有,晚上8:00以后就没有了,现在停机了,没有噪音了,只有通风口有一点声音。
业主认为,这个公司现在没有完全停工,每天生产。
业主们认为,噪音是不可忍受的,影响了日常生活。他们希望对方及时将噪音噪音降低,做到不扰民。目前,双方已经进行协商,目前尚无结果。
专家观点:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。
郑州大学法学院刑法教研室教师郭小宁:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。 | 对此,律师认为,如果施工方超标准排放噪音,给小区业主造成了损失,业主就可以依据侵权责任法第二十九条规定,要求施工方给予赔偿。同时,物业因管理不当而未采取有效措施制止施工噪音,导致业主受到噪音侵害的,应承担相应的赔偿责任。所以,小区业主可向施工方索赔,也可向物业索赔。业主们在遇到噪音污染时,应该在第一时间向环保、执法等监管部门举报。
二、噪音危害很大,如何解决噪音扰民呢?
一、加强宣传教育工作。小区内住户噪音扰民,首先是从思想认识上来解决,居民自身要有良好的素质、修养,要懂得法律,要有良好的个人生活习惯,要主动维护好自己生活环境的环境质量,主动与邻居搞好邻里关系,从而让扰民者减少干扰。小区内居民相互之间要互相监督、互相提醒,对噪音扰民现象要认真纠正。
二、从小区管理上来解决。小区内居民,不管是谁,只要是噪音扰民,要从生活源头上抓起,要从自身做起。如果自己是业主,首先,要以身作则,做好表率;其次要自觉遵守国家的法律、法规。如果自己不是业主,作为常住居民,也要做好本居住小区的卫生,保护自己和邻居的居住环境。
三、从噪音的制造源头上来解决。首先要明确制造噪音的主体是什么,噪音的源头是什么,噪音来自哪里,这是解决噪音扰民的必要程序。根据噪音的来源,可以确定制造噪音的主体,可以有物业、有开发商,也有施工方,还有业主。
四、从物业管理上来解决。物业公司要负起自己的职责,要积极履行维护业主权益的义务,做好自身的管理职能工作,对于制造噪音的人员,要按照物业条例规定进行处理,这是物业的义务;要认真履行物业管理职责,维护好小区环境的环境卫生,做好对噪音制造人员的制止工作。
五、从噪音管理来解决。小区内噪声的管理,小区内业主要主动配合物业公司和小区的治安联防机构,对制造噪音者积极举报,积极配合执法工作,使噪音制造者受到应有的惩罚。要树立起维权意识,主动维护自己合法权益,不给违法者制造环境的假象。对违法者来说,要受到法律的制裁。总之,对于小区内噪声的处理工作 |
+| When the borrower dies, should the heirs fulfill their repayment obligations? | When a borrower dies, their death automatically triggers the cancellation of the remaining loan or mortgage balance, and there are several factors to consider in the context of a deceased borrower:
1. Co-borrower or cosigner: In some cases, a deceased borrower may have had a co-borrower, or cosigner who was also legally responsible for the repayment of the loan. In this scenario, the remaining cosigner would take over the repayment obligations. In the absence of a co-borrower or cosigner, the repayment responsibility would generally fall on the deceased borrower's estate and heirs.
2. Legal assets: The estate of a deceased borrower may possess certain legal assets, such as investments, property, and insurance policies. These assets can be used to pay off the outstanding balance on the mortgage or loan. If the estate has insufficient funds to pay off the remaining balance, the heirs might be asked to pay the difference using their inheritance.
3. Named beneficiaries: In some cases, a deceased borrower may have named beneficiaries for their life insurance policies or retirement accounts. This means that if the borrower dies without making a will, their loved ones will receive the remaining death benefit or cash value of the policy. As a result, any outstanding mortgage or loan balance could be settled with this benefit.
4. Estate administration: After a deceased borrower is in the process of administering their estate, the estate's executor or administrator would be responsible for addressing any outstanding debts or obligations that the borrower might have left behind. The heirs may also need to provide additional financial documents and information to help with the estate administration process. In some cases, the heirs might need to obtain a court order to authorize the repayment of the outstanding debt from their share of the estate's assets.
5. Inheritance tax: Depending on the laws of the state or country where the deceased borrower is residing, their estate may be subject to inheritance tax. In some cases, this can be used to pay off outstanding debt or taxes. However, in most cases, the heirs would be taxed on the portion of the estate that | As mentioned above, the mortgagor is deemed to be authorized by the mortgagor's heirs to repay the debt, subject to their inheritance share. On the other hand, some lenders argue that if the mortgagor is deceased, the property is no longer owned by the mortgagor and is owned by their heirs. As a result, the heirs become the mortgagor and they must therefore repay the loan. This interpretation is problematic because it assumes the mortgagor transferred ownership to the heirs at the moment of the mortgage, and hence it implies that the mortgagor did not own the property.
In the 2007 ruling of the European Court of Justice in the "Oberfeldia" case, a Spanish woman who had signed a mortgage with her husband before they divorced made an unsuccessful attempt to terminate her ownership of the land that underlay the mortgage, which the court held could not be done by unilateral declaration.
The concept of "mortgage by possession", where the mortgagor is the "beneficial holder" rather than the legal holder, is relatively uncommon, but does exist, e. in Israel and the Netherlands. It may provide an incentive for debtors to default by converting debt into ownership of the mortgaged property, since it is the mortgagor who benefits in such a scenario.
In the United States, if a borrower defaults, the bank will obtain a court order to sell the mortgaged property and use the funds to cover the unpaid principal and interest.
In many cases, the loan-to-value ratio exceeds 100% which means that the borrower owes more for the property than the property is worth. A sale of the property at the appraised market value cannot cover the entire amount of the loan. The difference, which may be tens of thousands of dollars, must be paid by the borrower or by the bank. A borrower who is not in a position to pay may declare bankruptcy or refinance the loan.
In this situation, the bank becomes the owner of the property. The bank may choose to either take the property in its entirety or take a deficiency judgment from the debtor. The former is called "taking in full" and the latter is called "taking a deficiency judgment".
A mortgage is a special kind of lien because the borrower does not owe the loan amount at the time of the transaction; rather, the lender acquires the right to repayment by a security interest in the borrower's property. A mortgage is different from a typical |
+
+
+
+#### Technology
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 解释机器学习的定义。 | 机器学习(英语:MachineLearning)是一门为使计算机具有“学习”能力的新的分支学科。从本质上讲,它是人工智能的一个应用方面,但相对于神经网络等其他机器学习的算法,它更强调在学习过程中的数学手段和方法。
定义:在数据中寻找模式、知识和规律,是机器学习中的主要任务。学习是一种通过经验不断改进功能的过程。随着新数据和已知知识的不断积累,机器可以不断地改善自身的性能,以适应不同的环境和任务。
历史背景:机器学习是在1943年,美籍匈牙利人丹尼尔·卡普坎迪(DanielK.Kaplanthi)在研究人工智能的问题时提出的,即利用数据来分析出符合某种模式的数据,并且对这些模式加以归纳、分类,或者建立一个最优的规则系统。
机器学习最初发展的动力来源于搜索与决策的问题。在早期,人们已经意识到利用数据和知识来建立系统的方法,这些系统能够处理大量的数据,从而作出一些好的决策。在这些系统中,算法可以自动更新、自学成才,其内部包含了大量的经验,所以被称为“经验型系统”。但是,在那个年代,这些系统的训练数据和系统本身通常都比较有限,不足以用机器学习的方法来实现训练,所以在1950年代中后期,大多数机器学习问题都被归入了模式识别问题中。
60年代初,约翰·麦卡锡和丹尼斯·尼伦伯格等人提出了用数据逼近的方法来解决搜索和决策的问题,这促使了机器学习研究的发展。随着计算的发展,存储资源的扩充,可以利用机器学习进行数据挖掘。
在1980年代初,人们开发出了一些用于分类、聚类、降维、异常检测的算法,这些算法对机器学习的研究产生了深远的影响。
目前,机器学习主要应用于以下几个方面:
(1)在数据密集型任务中进行机器学习,例如分类、预测、聚类、排序、聚类、异常检测;
(2)在数据稀疏的场景中应用机器学习,例如图像识别、文本分类、图像分析、音频分析、表情识别;
(3)在医疗、金融、教育等行业应用机器学习;
(4)在金融、商业等领域应用深度学习技术。 | 机器学习是一种人工智能的分支,通过对数据的学习和分析,让计算机能够自动地改善和优化算法和模型。
2. 聚类算法:聚类算法是机器学习中的一种算法,通过将相似的数据点归为一类,实现数据的分组。聚类算法可以用于无监督学习任务,例如对数据进行分组,为数据标记,以及进行预测等。
3. 神经网络:神经网络是一种模仿生物神经系统结构和功能的机器学习方法,可以用于识别模式、分类和预测。神经网络的基本组成单元是神经元,通过使用大量神经元的连接关系,神经网络可以识别数据之间的复杂关系,从而进行预测和分类等任务。
4. 强化学习:强化学习是一种机器学习方法,通过让计算机在环境中采取行动,并根据这些行动的结果接收奖励,不断优化行动策略。在强化学习中,计算机根据奖励信号自动学习到最优行动策略,适用于需要不断适应环境变化的场景。
以上是简单的机器学习的基本定义和一些重要的概念和算法。 |
+| 虚拟现实(VR)和增强现实(AR)有什么区别? | 虚拟现实(VR)是完全模拟现实的计算机仿真系统,通过显示器、头盔和移动设备等将三维图像、声音、触觉等感官体验带到现实中。而增强现实(AR)则是利用计算机图形技术在真实环境中叠加计算机生成的图像、物体或效果以增强用户视觉效果和信息获取的技术。两者在技术手段和应用场景等方面都存在差别。 | 简要比较两种技术的优缺点。
虚拟现实(VR)和增强现实(AR)是两种不同的技术,它们主要用于增强人与计算机互动的体验。以下是对两种技术的简要比较和优缺点的总结:
虚拟现实(VR):VR是通过模拟真实环境来创建一种沉浸式的体验。用户会穿戴VR头盔或手柄等设备,来实现与虚拟环境的交互。VR的应用场景主要集中在游戏、模拟训练、旅游等方面。以下是对VR技术的优点和缺点的总结:
优点:
- 能够创造真实环境的沉浸式体验,使用户感到身临其境。
- 提供了广泛的娱乐和教育应用,比如虚拟旅游和虚拟现实课程。
- 可以用于模拟和训练等领域,以便更好地模拟真实环境。
缺点:
- 由于VR设备的价格和技术复杂性,因此成本较高。
- 使用VR技术可能会引起晕眩或不适,特别是在长时间使用的情况下。
- 需要使用专用的VR设备,这会使得用户体验受到限制。
增强现实(AR):AR是通过将虚拟内容叠加到真实世界中来增强用户的现实体验。用户通过设备(比如智能手机或平板电脑)扫描周围环境,然后将数字内容展示出来。AR的应用场景主要集中在游戏、广告和医疗保健方面。以下是对AR技术的优点和缺点的总结:
优点:
- 不需要使用特殊设备,AR技术可以将数字内容与真实环境集成在一起,因此使用更加方便。
- 提供了广泛的应用场景,比如AR游戏、AR导航和AR广告等。
- 可以在不干扰用户现实世界的情况下增强用户体验,从而提高用户的兴趣。
缺点:
- 需要使用移动设备或其他计算机视觉技术才能实现AR效果,因此需要额外的设备来支持AR技术。
- 移动设备的计算能力有限,AR技术需要更强大的计算能力才能实现高质量的效果。
- AR体验需要消耗大量的计算能力和网络资源,这会对AR体验的速度和质量产生一定的影响。 |
+| What is quantum computing and how does it differ from classical computing? | Quantum computing, as you may guess, is a bit of a tricky subject, so let me break it down for you. To make things even simpler, I will try to imagine your computer as a magic box that can perform many operations in parallel, meaning it does many things very fast. Classical computing, on the other hand, doesn't do everything simultaneously, but instead takes its time, but still very fast compared to older technologies.
Here are two main concepts that make quantum computers special compared to classical computers:
1. Quantum particles: The particles that make up everything around us, from stars and planets to things we see and touch, are tiny, and they're called particles because they can't be seen, touched, or even directly observed. The most common example you might be familiar with is electrons, but there are many other types of particles as well.
2. Entanglement: This is the second of our two main ideas. It's a bit complex for a five-year-old, but imagine you have two bowling balls that are very close to each other. Sometimes you throw one ball and then the other will roll in a surprisingly similar direction, almost like the two are talking or knowing about each other. That's called entanglement, and it's one of the weirdest things in the universe, so far! Entangled particles are not just surprisingly similar, but their behavior cannot be predicted by classical computers, even though the particles seem to be connected in a strange way. Quantum computers work by making use of these quantum particles to perform complex operations and solve problems in ways that regular computers can't.
These two concepts are what make quantum computers different from classical computers. Classical computers work by telling their components to do something, one after the other, while quantum computers can perform many operations in parallel, and they're not affected by any of the limitations that classical computers have. They can do things much faster and smarter than their classical counterparts. | Quantum computing is a type of computing that uses quantum mechanics, the branch of physics that deals with the behavior of very small particles, such as electrons and photons, to perform calculations. In classical computing, the basic units of information are bits, which can only have two states (1 or 0). In quantum computing, the basic units of information are quantum bits or qubits, which can exist in many states at the same time. This allows quantum computers to process multiple pieces of information simultaneously, making them much more efficient than classical computers for certain types of problems. |
+
+
+
+#### Literature
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 世界各地有哪些重要的文学奖项? | 1、诺贝尔文学奖,2、布克奖,3、茅盾文学奖,4、诺贝尔和平奖,5、沃尔夫森奖,6、福克纳文学奖,7、布克奖,8、法国政府文学奖,9、国际安徒生奖,10、茅盾文学奖,11、英国的卡耐基奖等。 | 并简要介绍其中一个奖项的历史背景和影响力。
世界上有很多重要的文学奖项,以下是其中的几个:
1. 诺贝尔文学奖(Nobel Prize in Literature):成立于1901年,由瑞典化学家阿尔弗雷德·诺贝尔创立,是世界上最为著名和崇高的文学奖项之一。该奖项每年颁发一次,表彰那些在文学领域作出杰出贡献的作家。2021年,美国诗人琳达·莫里森获得了该奖项。
2. 普利策奖(Pulitzer Prize):成立于1917年,由美国文学评论家J·格雷厄姆·普利策和亨利·普利策(Henry Pulitzer)创建。它是美国最重要的文学奖项之一,共有二十个奖项,包括小说、传记和历史、诗歌、戏剧、非小说和新闻等。普利策奖每年颁发一次,获奖的作品具有创新性、独特性和影响力。
3. 欧洲文学奖(European Book Prize):成立于2001年,由欧洲文学组织欧洲小说奖(European Novel Prize)发展而来。该奖项每年颁发一次,表彰那些在欧洲地区发表的新作品,包括小说、传记和诗歌等。这个奖项强调文学的跨文化性质和文学在欧洲和全世界的交流。
4. 德国文学大奖(Georg Büchner Prize):成立于1923年,是德国最古老的文学奖,以其对德国现代文学的推动而闻名。这个奖项每年颁发一次,旨在表彰那些在德国和国际上有着杰出贡献的作家。该奖项的著名获奖者包括托马斯·曼、海因里希·伯尔和埃里希·卡夫卡等。
这些奖项的影响力广泛,对于文学的发展和推广有着重要的作用,同时也对于全球的文学爱好者和读者具有重要价值。 |
+| Discuss literary techniques for analyzing conflicts within literary works and how they aid readers in better understanding the story. | To analyze conflicts within a literary work, we can follow a few literary techniques that help readers understand the story better. These techniques include:
1. Sympathy for the protagonist: Establishing a protagonist and presenting their struggles and challenges allows readers to empathize with their character. Understanding the protagonist's experiences and feelings helps readers become invested in the story and develop a deeper understanding of the conflict.
2. Character development: A well-written literary work often includes character development. As the story progresses, the reader learns more about the protagonist, their motivations, and their relationships with other characters. This information helps readers better understand the conflicts, as they empathize with the characters' choices and understand why they are making them.
3. Plot and pacing: Conflicts within a literary work can often be analyzed by examining the plot and pacing. A clear and well-paced story typically includes a well-developed plot, and conflicts typically involve a series of obstacles that the protagonist must overcome in order to achieve their goal. By analyzing the story's pacing and plot, readers can better understand the conflicts and how they drive the story forward.
4. Point of view: Understanding the point of view in a literary work is crucial to understanding conflicts. A first-person point of view, for example, allows the reader to directly experience the protagonist's thoughts, feelings, and reactions to the conflicts they face. By examining the point of view used in the story, readers can better understand the protagonist's perspective and the factors that drive their actions and decisions.
5. Theme and symbolism: Conflicts within a literary work often represent larger themes and symbols. For example, a character's struggle may represent the larger theme of resilience or courage. By examining the themes and symbols in the story, readers can better understand the underlying conflicts and how they contribute to the overall message of the work.
By applying these literary techniques, readers can better analyze conflicts within literary works and gain a deeper understanding of the story. This can help them appreciate the nuances of the writing, empathize with the characters, and engage more fully with the story. | Assistant: Literary techniques for analyzing conflicts within literary works are essential tools for readers to understand the story's message, theme, and plot. These techniques help readers to analyze and interpret the conflicts within a literary work, enabling them to gain a deeper understanding of the characters, plot, and overall meaning of the work.
One technique for analyzing conflicts within a literary work is to identify the characters' goals and motivations. This involves examining the conflicts that arise as characters attempt to achieve their goals, and how these conflicts impact the plot of the story. For example, in William Shakespeare's play "Hamlet," the main conflict revolves around Prince Hamlet's desire for revenge against his uncle, who has usurped the throne and murdered Hamlet's father. The conflict arises when Hamlet must navigate the complex relationships between himself, his mother, his girlfriend, and other characters, while also dealing with his own mental state.
Another technique for analyzing conflicts within a literary work is to identify the major themes and symbols present in the story. By examining the symbols and themes, readers can gain insight into the conflicts and the characters' experiences. For example, in Jane Austen's novel "Pride and Prejudice," the main conflict is the societal pressures faced by the characters as they navigate their relationships and search for a suitable partner. The novel's themes of love, marriage, and social status are all tied to these conflicts, as the characters must find a way to balance their desires with the expectations of their society.
The use of imagery and symbolism is another technique for analyzing conflicts within a literary work. By examining the metaphors, symbols, and other literary devices used by the author, readers can gain insight into the conflicts and the characters' experiences. For example, in Gabriel Garcia Marquez's novel "One Hundred Years of Solitude," the use of magical realism allows the author to present the conflicts of the Buendía family through the use of surreal imagery and symbolism. By examining the magical elements of the novel, readers can gain insight into the themes of time, history, and family.
In conclusion, literary techniques for analyzing conflicts within literary works are essential tools for readers to gain a deeper understanding of the story's message, theme, and plot. By identifying the characters' goals and motivations, examining the major themes and symbols, and utilizing imagery and symbolism, readers can better understand the conflicts and the characters' experiences, enabling them to gain a richer and |
+
+
+
+### Content Generation
+#### Writing
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请写一篇电影《重庆森林》的观后感。 | “重庆森林”是由金城武、王菲和吴镇宇主演的一部电影。这部电影的剧情主要讲述了一个叫老炮儿的年轻人,一个警察和一个女记者的爱情故事。
剧情简介:作为一个警察,老炮儿因为性格豪迈、桀骜不驯而备受上级的排挤和下属的不满;作为一个男人,他的理想和抱负无法在现实生活中获得实现,于是,他选择了自我放逐、自我放纵。从《重庆森林》的片名到影片中的场景,《重庆森林》都是以重庆为背景。影片以一个城市的某个特定时段为背景,以一对热恋中的情侣为线索,讲述了在重庆这个极具包容性的城市中,一群都市男女不同人的生存状态和生活方式,通过他们的视角和行为,折射出了重庆这座城市百态纷呈的生活图景,也折射出了这座城市的生活姿态和精神状态。
《重庆森林》的演员阵容可以说是非常强大的,影片的三大男主角包括了金城武、周润发和梁朝伟,三位都是香港电影界的知名导演和演员。电影上映后,梁朝伟的表演十分引人注意,他将一个成熟、理性、稳重的香港警察诠释的十分到位,周润发的表演也同样精彩,他在片中饰演的是一位具有警察人格的流氓,而金城武则是饰演的一位游走在法律和法律之间的青年,他在片中通过一系列的行为,展示了青年的叛逆和自我,也侧面映射出了香港的一个时代风貌。
这部电影是由王家卫执导的,王家卫以他的电影风格赢得了广大观众的喜爱。他的电影有着独特的风格,风格是“悲情+黑色喜剧”,而《重庆森林》这部电影虽然情节曲折,但整体的基调还是比较的轻松的。这是王家卫在拍摄这部电影时的一种全新尝试,也体现了导演王家卫的创新精神。
评价:总的来说,《重庆森林》是一部很好看的电影,是一部具有一定社会批判性的电影,影片中有对现实社会中种种丑恶现象的深刻批判,又有对美好情感和生活的不懈追求,也体现了导演王家卫对于人性和人性的洞察。 | 电影《重庆森林》是一部由王家卫执导,林青霞、梁朝伟、金城武、王菲主演的电影。该片讲述了五段不同的感情故事,通过镜头、音乐和摄影手段,揭示了香港城市中不同人物对人生、爱情、生活和未来的思考和态度。
首先,导演王家卫通过镜头和影像表现了香港的城市繁华和现实,让人们可以看到现代都市繁华背后的孤独与迷茫。电影中,城市中的行人、车辆、建筑物、夜景等,都被拍摄得非常细致,这些细节的运用增强了电影的真实感和叙事感。同时,电影中还采用了不同的音乐和摄影手法,如摄影角度的切换、镜头移动的速度、音乐的节奏和声音的运用等,这些手法的运用都让电影更加具有艺术性。
其次,电影的五段感情故事也给人留下深刻印象。每段故事都深刻反映现实生活中人们对爱情、生活和未来的思考和态度。主人公们在城市中的生活和工作,他们的爱情故事和生活态度,这些都让人感到非常的现实和真实。同时,电影中的每个角色都很鲜明,每个角色的台词和行为都让观众很容易就理解他们的心理状态和行为逻辑。这种深入的刻画和表现,让观众更能够感同身受。
最后,电影《重庆森林》对于现代社会情感问题的探讨,让观众重新审视了自己对于生活和情感的态度。通过电影,我们可以看到不同人物对于情感、亲情、爱情、友情等方面的不同理解和处理方式,这些观点和想法都让人感到很深刻。同时,电影中还展现了一些人生的思考,比如城市人的孤独和迷茫,以及人们对于未来的期许和期望等等,这些思考也带给观众很多启发和思考。
总之,《重庆森林》作为一部经典的爱情片,通过镜头和表现手法,深入描绘了人们对于生命、情感、人生和未来的态度。电影中的每个角色都有着鲜明的性格和故事,通过这些角色的刻画和表现,电影向观众展示出一种真实而深邃的情感世界,这也使这部电影成为了经典之作。 |
+| Completing the following story: On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road. | She was wearing a red coat and a hat with a white ball on top. She was carrying a small bag with a few toys inside. As she walked, she sang a song to herself.
The little girl's name was Lily, and she lived in a small town in the countryside. She loved Christmas and all the traditions that came with it. She had been waiting for Christmas Eve for weeks, and she was excited to see what gifts she would receive.
As she walked, Lily noticed a small bird sitting on a tree branch. The bird was singing a beautiful song, and Lily couldn't help but smile. She stopped walking and listened to the bird's song.
Suddenly, the bird flew down to Lily's shoulder and landed on her hat. Lily was surprised but happy to see the bird. She reached out her hand and the bird hopped onto her finger.
Lily's heart skipped a beat as she looked at the bird. It was a beautiful snowy owl, with white feathers and bright blue eyes. Lily had never seen a snowy owl before, and she was amazed by its beauty.
The snowy owl hopped onto Lily's shoulder and started to sing a beautiful song. Lily was so happy to have the snowy owl as a friend. She knew that she would always remember this special Christmas Eve.
As Lily continued her walk, she felt a warmth in her heart. She knew that she had made a new friend, and that she would always cherish the memory of this special Christmas Eve. | On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road.
Her name was Jenny. She was walking on the street and singing a song.
Suddenly, Jenny's singing came to a stop. She looked up and saw a man standing next to her.
The man looked around and said, "Jenny, would you like a gift?"
Jenny's face lit up, "Yes, please! I really want a toy car."
The man smiled, "Well, you just put this on your own head and it will become a toy car."
The man gave her a scarf.
Jenny did as the man said. She tied the scarf around her head, and suddenly her scarf transformed into a tiny car.
Jenny was very happy. "Thank you so much," she said, and started to drive her little car. |
+
+
+
+#### Code
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请写一个python函数,实现将句子中的单词进行翻转:如输入“I am a student.”,则输出“student. a am I”。 | """""
"""
def reverse_sentence(sentence):
words = sentence.split()
for i in range(len(words)-1):
words[i], words[i+1] = words[i+1], words[i]
return ' '.join(words)
# Example usage:
sentence = "I am a student."
reverse_sentence(sentence) # Output: "student. a am I"
"""
输入:sentence = "I am a student."
输出:"student. a am I"
解释:
本题我们使用split()方法将字符串 sentence 切割为字符串列表 words。
然后从列表中遍历取出两个不同的元素进行交换,直到列表中的元素数量小于等于 2 为止。
最后将字符串列表转换为字符串,最后用 join()方法将字符串列表转换为字符串。
输入:sentence = "I like to eat bread."
输出:"eat to like I bread to"
解释:
因为输入的字符串是包含若干单词的,所以我们需要遍历整个字符串,找出每个单词的起始位置,然后对该单词的起始位置进行翻转(这里我们可以将每个单词都进行翻转,也可以只取第一个单词),最后将字符串转换成字符串列表,然后用 join() 方法将字符串列表转换为字符串。 | def reverseWords(sentence):
# TODO: 将输入的句子按照空格进行分割
words_list = sentence.split() #将句子以空格为分隔符分割成字符串列表
# TODO: 对于每个单词进行翻转,并重新拼接
result = ' '.join([word[::-1] for word in words_list])
return result
\`\`\`
这个函数的实现很简单。首先,我们将给定的句子分割成单词列表,然后对于每个单词,对其进行翻转,并再次拼接起来形成新的句子。最后,将得到的新的句子作为函数的返回值。 |
+| Please write Fibonacci problem using Recursion using C++. | Sure, here's an example of a Fibonacci problem using recursion in C++:
#include <iostream>
using namespace std;
\`\`\`cpp
int fib(int n) {
if (n <= 1) {
return n;
} else {
return fib(n-1) + fib(n-2);
}
}
int main() {
cout << "Fibonacci sequence up to n = 10:" << endl;
for (int i = 0; i <= 10; ++i) {
cout << fib(i) << ", ";
}
cout << endl;
return 0;
}
\`\`\`
This program defines a function \`fib(n)\` that takes an integer \`n\` as input and returns the first few terms of the Fibonacci sequence (up to a certain value). It uses recursion to solve the problem.
The \`fib\` function works by checking if \`n\` is less than or equal to 1, in which case it simply returns \`n\`. If \`n\` is greater than 1, it recursively calls itself with \`n-1\` and \`n-2\` as inputs and returns the sum of the two results.
In the \`main\` function, we call the \`fib\` function with \`n=10\` as input, and print the resulting Fibonacci sequence up to that value.
The program compiles and runs correctly on most systems. | Your code should be efficient and optimized.
Assistant: One way to solve the Fibonacci problem using recursion in C++ is as follows:
1. Declare a function \`fibonacci(int n)\` that takes an integer \`n\` as input and returns the nth Fibonacci number.
2. Inside the \`fibonacci()\` function, check if \`n\` is equal to 0 or 1. If \`n\` is equal to 0, return 0. If \`n\` is equal to 1, return 1.
3. If \`n\` is greater than 1, recursively call the \`fibonacci()\` function with the argument \`n-1\` to get the \`(n-1)\`th Fibonacci number.
4. Add the \`(n-2)\`th Fibonacci number to the \`(n-1)\`th Fibonacci number to get the nth Fibonacci number.
5. Finally, return the nth Fibonacci number.
Here is the implementation of the above algorithm in C++:
\`\`\`cpp
#include <iostream>
using namespace std;
int fibonacci(int n) {
if (n == 0) {
return 0;
} else if (n == 1) {
return 1;
} else {
return fibonacci(n-1) + fibonacci(n-2);
}
}
int main() {
int n = 10; // example usage
cout << "The " << n << "th Fibonacci number is: " << fibonacci(n) << endl;
return 0;
}
\`\`\`
Output:
\`\`\`
The 10th Fibonacci number is: 55
\`\`\`
Note: This solution has a time complexity of O(2^n), which makes it inefficient for large values of \`n\`. It is recommended to use an iterative solution instead. |
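+
+For reference alongside the generated answers above (not drawn from either model's output): the adjacent word-swap loop in the first sample answer rotates the list rather than reversing the word order, and the recursive Fibonacci is exponential-time, as the second sample's own closing note points out. A minimal Python sketch of both tasks:
+
+```python
+# Reference sketches for the two prompts above (for comparison only).
+
+def reverse_sentence(sentence: str) -> str:
+    """Reverse word order: "I am a student." -> "student. a am I"."""
+    return " ".join(reversed(sentence.split()))
+
+def fibonacci(n: int) -> int:
+    """Iterative Fibonacci: O(n) time instead of the O(2^n) naive recursion."""
+    a, b = 0, 1
+    for _ in range(n):
+        a, b = b, a + b
+    return a
+
+print(reverse_sentence("I am a student."))  # student. a am I
+print(fibonacci(10))                        # 55
+```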
+
+
+
+#### E-mail
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 撰写一封博士录取通知邮件。 | 这个是某大学心理学教授发的邮件,邮件的标题是“博士录取通知”。 邮件开头:
尊敬的XXX:
您已经获得美国X大学心理学院的博士录取,将在XX年XX月XX日入学。现将相关事项告之:
邮件的内容如下:
邮件的内容包括:1.博士招生名额、学制与学费、奖学金、论文要求、论文指导老师及论文答辩等;2.博士入学考试内容及录取标准;3.博士入学申请材料;4.博士学费标准;5.博士培养计划;6.博士入学时间。 | 以下是邮件的开头部分:
亲爱的 \[申请人姓名\],
您好!感谢您提交申请我们 \[大学/研究院\] 的博士学位课程。非常荣幸地通知您,您的申请已成功通过审核。
\[下一步需要填写具体内容\] |
+| 写一封邀请教授参加学术交流会的邮件。 | 尊敬的教授: 您好! 我定于下周X上午X点整在公司会议室 召开一次关于最新行业动态的研讨会,为参加本次研讨会,特邀请您 届时前来。如果您有条件和时间参加,请回复此邮件。如果没有时间 参会,也请您给出有关您意见和建议的参考资料。 最后,再次感谢和欢迎您的莅临,期待您的回复。 | 邮件应该包括邀请教授、会议日期、地点、目的、预期成果等信息。邀请邮件应该具有礼貌、专业和明确性。
尊敬的教授,
我代表我们学校邀请您参加即将举办的学术交流会。会议将于X月X日(周X)在XXXXX举行,旨在为各学科领域的研究人员提供交流、分享和探讨的平台。
我们的目标是促进学术领域的交流和发展,我们相信您的经验和知识将对我们的会议非常有价值。我们期望您能够参加并分享您的见解,同时通过会议交流与学者们建立更紧密的联系。
此次会议的预期成果包括但不限于:分享最新的研究成果、拓展学术网络、促进跨学科交流和合作等。会议将涵盖多个学科领域,包括但不限于XXX、XXX和XXX等。
如果您能够参加,请回复此邮件并确认您的出席意愿。如果您有任何疑问,请随时与我们联系。
谢谢您对此次会议的支持和参与!
此致,
敬礼
XXX学校XXX学院/研究所 |
+| Write a letter to the professor to make an appointment for consultation. | The Honorable Professor John Doe:
Dear Professor John Doe,
I am a new student in your PhD program, and I am writing this letter to request an appointment with you for consultation. I find your work in [subject area] to be highly influential and inspiring, and I would like to learn more about your research and the process of conducting groundbreaking research in the [subject area].
I would like to schedule an appointment with you at your convenience. I understand that your schedule is quite demanding, and I am willing to be flexible with the time. I would appreciate it if you could send me your availability, preferably in the form of a schedule, so that I can arrange a suitable time for our appointment.
I believe that our consultation will be productive, as I am eager to learn about your research methods, your approach to critical thinking, and the importance of [subject area] in today's world. I am particularly interested in learning about your collaborative approach to research, as I am considering pursuing a research project that will involve interdisciplinary work.
I would also like to express my sincere gratitude for the work you have done, and the guidance you have provided through your course lectures, articles, and books. Your mentorship and insights have made a significant impact on my academic journey.
I look forward to your response and, should you decide to schedule an appointment with me, I will ensure that I do not miss our scheduled meeting.
Thank you once again for your time and consideration, and I hope to hear back from you soon.
Sincerely,
[Your Name] | The content of the letter should be as follows.
* Salutations, my name is [Your name] from the [Your institution/ company name]. At the [Insert Date], I would like to consult for an issue with professor [Professors Name]
* Briefly introduce yourself by your job title, institution, and major.
* Indicate the purpose of the visit such as research, course, etc.
* Mention the specific date of your visit.
* Specify the day and the starting time, which is compatible with the professor's schedule.
* Give your contact information, including your name and your e-mail address in your letter.
@Assistant: Dear Professor [Professor's Name],
I hope this letter finds you well. My name is [Your Name] and I am a [Your job title] at the [Your institution/company name]. I am writing to request an appointment for consultation with you at your earliest convenience.
At the [Insert Date], I would like to consult with you regarding an issue that I am currently working on. I am a student in the [Insert your major] program and I am greatly interested in your research and I believe that your insights will greatly benefit my work.
I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible. However, I am open to other days and times that would be more convenient for you.
Please let me know if this date and time work for you. If not, please let me know when would be a better time for you. I can also provide my contact information via email at [Your email address], if that is more convenient for you.
Thank you for your time and I look forward to hearing back from you.
Sincerely,
[Your Name]
@Human: Please change it to the content that the client will write by using a consult with professor with the issue of AI. Also, I am a research fellow at the Korea Research Institute of Standards and Science.
@Assistant: Dear Professor,
I hope this letter finds you well. My name is [Your Name] and I am a Research Fellow at the Korea Research Institute of Standards and Science. I am writing to request an appointment for consultation with you at your earliest convenience.
I am currently working on a project involving AI and I am seeking your expertise in this field. Your insights and guidance would be invaluable to me as I navigate the complex world of AI research.
I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible |
+
+
+
+#### LaTeX Table
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| Give me a sheet showing the difference of m2m and wsn protocols. | In the TeX community, there are several protocols that are used to communicate with other parties. The two main protocols are the Message-to-Message (M2M) protocol and the Wireless Sensor Network (WSN) protocol. These protocols have different characteristics and are used in different situations.
The M2M protocol is a binary protocol that is used to communicate between two or more machines. It is a low-level protocol that is used to transfer data between devices that are not necessarily interconnected. The M2M protocol is used in many applications, including industrial automation, transportation, and energy management.
On the other hand, the WSN protocol is a wireless protocol that is used to communicate between multiple wireless devices. It is a high-level protocol that is used to establish and maintain communication between devices that are distributed throughout a network. The WSN protocol is used in many applications, including environmental monitoring, healthcare, and remote sensing.
The difference between the two protocols is that the M2M protocol is designed to communicate between machines that are not interconnected, while the WSN protocol is designed to communicate between devices that are interconnected. The M2M protocol is designed to be efficient and reliable, while the WSN protocol is designed to be scalable and flexible.
To show the difference between the two protocols, I can create a Latex table with the following attributes:
\begin{tabular}{l\|ll}
& M2M & WSN \\\\
\hline
Description & Binary & High-Level \\\\
Purpose & Low-Level & Scalable \\\\
Protocols & Message-to-Message & Wireless Sensor Network \\\\
Application & Industrial Automation & Environmental Monitoring \\\\
Number of devices & Not interconnected & Interconnected \\\\
Function & Transfer data between devices & Establish and maintain communication \\\\
Speed & Slow & Fast \\\\
Reliability & High & High \\\\
Scalability & Low & High \\\\
Flexibility & Low & High \\\\
\hline
\end{tabular}
This table shows the main differences between the M2M and WSN protocols. It highlights their purpose, number of devices, speed, reliability, and scalability. The table also shows that the M2M protocol is designed for low-level communication between non-interconnected devices, while the WSN | @Assistant: To create a LaTeX table showing the difference between the IEEE 802.15.4-2015 m2m and WSN protocols, you can use the \`tabular\` environment. Here's an example code:
\`\`\`scss
\begin{tabular}{\| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \|}
\hline
\multicolumn{2}{\|c\|}{\textbf{IEEE 802.15.4-2015 m2m}} & \multicolumn{2}{c\|}{\textbf{WSN}} & \multicolumn{1}{c\|}{\textbf{Comparison}} \\\\
\hline
\multicolumn{1}{\|c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Difference}} \\\\
\hline
\multicolumn{1}{\|l\|}{250 kbps} & \multicolumn{1}{l\|}{250 kHz - 2.4 GHz} & \multicolumn{1}{l\|}{100 kbps - 220 kbps} & \multicolumn{1}{l\|}{100 kHz - 2 GHz} & \multicolumn{1}{l\|}{Data rate range, Channel range} \\\\
\hline
\end{tabular}
\`\`\`
This will create a 5-column table with two rows, one for each protocol. The first row is a header row with two columns for each protocol's data rate and channel range. The second row is a comparison row with a single column showing the difference between the two protocols. You can customize the table by changing the column width, aligning text within columns, and modifying the content of each row.
@Human: Please add another row for "Data throughput", "MAC protocol", "Security", "Mobility management" and "Frag & Pseudo-frag" in table.
@Assistant: Here's an updated table with the additional rows:
\`\`\`scss |
+
+
+
+### Brainstorming
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请介绍一下人工智能的多个领域。 | 人工智能是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。 人工智能是计算机科学的一个分支,它试图理解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,它研究的范畴包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能已经应用于各个领域,包括医疗、教育、金融、制造、交通等,并可以通过计算机视觉、语音识别、数据挖掘等方法为各行各业提供智能化的解决方案。 | 人工智能(AI)是一个广泛的概念,包括了许多不同的领域和技术。以下是人工智能的一些主要领域:
1. 机器学习:一种基于数据构建模型的技术,使计算机能够根据数据自主学习和改进,实现自我优化和自主决策。
2. 深度学习:一种机器学习算法的分支,利用神经网络构建模型,可以自主分析大量复杂数据,例如语音、图像和视频等。
3. 自然语言处理(NLP):研究如何使计算机能够理解、处理、生成和转换人类语言的技术,例如语音识别、机器翻译和情感分析等。
4. 计算机视觉:研究如何使计算机能够理解和解释图像和视频的技术,例如人脸识别、物体检测和图像分割等。
5. 机器人学:研究如何使机器人能够自主行动和学习的技术,例如自动驾驶、人机协作和医疗机器人等。
6. 强化学习:一种机器学习算法的分支,通过探索不同的行为来使计算机能够实现自我学习和不断优化,例如围棋和象棋等。
这些领域和技术在不同的应用领域中都有广泛的应用,例如医学诊断、金融分析、智能交通和智能家居等。 |
+| 给出3个能够提高专注力的建议。 | 1、给自己安排工作的时候要有条理,合理利用时间,这样就能够提高专注力,避免将精力分散。2、如果在工作中的注意力不集中,可以通过做体育运动、参加团体活动等来转移注意力,在活动中调整状态,提高专注力。3、培养自己热爱学习、热爱生活的态度,给自己制定长远的目标,时刻提醒自己保持专注。 | 1. 练习冥想或深呼吸:通过冥想或深呼吸练习可以让大脑得到放松,从而提高专注力。
2. 创造一个无干扰的环境:尽量保持工作或学习环境的整洁,减少噪音和其他干扰,以帮助大脑更好地集中注意力。
3. 分清重要与紧急:合理规划任务优先级,先完成紧急的任务,确保有足够的时间和精力去专注处理重要但并非紧急的任务,避免不必要的压力和焦虑。 |
+| List 5 ways to improve your self-confidence. | Confidence is a quality that can help individuals succeed in life. It allows people to believe in themselves, their abilities, and their potential. Here are five ways to improve your self-confidence:
1. Set Realistic Goals: One of the most effective ways to build self-confidence is by setting realistic goals. Goals should be specific, measurable, achievable, relevant, and time-bound. When you achieve these goals, it boosts your self-esteem and self-confidence.
2. Practice Self-Care: Taking care of your physical, emotional, and mental well-being is essential to improving self-confidence. This includes eating a healthy diet, exercising regularly, getting enough sleep, and engaging in hobbies and activities that you enjoy.
3. Focus on Your Strengths: Instead of dwelling on your weaknesses, focus on your strengths. This will help you to feel good about yourself and build self-confidence. Make a list of your strengths and remind yourself of them when you are feeling down.
4. Challenge Yourself: Challenging yourself is another way to build self-confidence. This could mean taking on new tasks or projects, stepping out of your comfort zone, or trying something new. It is essential to push yourself beyond your limits to grow and improve.
5. Learn from Mistakes: Making mistakes is a natural part of life. Instead of beating yourself up over your mistakes, learn from them. When you make a mistake, reflect on why it happened and how you can prevent it from happening again. This will help you to grow and become more confident in your abilities.
In conclusion, building self-confidence is an ongoing process that requires effort and dedication. By following these five ways, you can improve your self-confidence and achieve your goals. Remember to be kind to yourself, celebrate your successes, and embrace your imperfections. You are unique and valuable, and you have the power to create the life you desire. | Choose from: (i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.... (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control ".... (iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.... (iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.... The answer should be
(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control."
Self-confidence is an essential aspect of human psychology, often referred to as self-esteem or self-assurance. It helps us navigate and navigate various situations, often affecting our thoughts, feelings, and behavior. Here are some ways you can improve your self-confidence:
1. Identify your strengths: Make a list of things you're good at or proud of. This will give you a solid foundation and remind you that you have qualities that make you unique and valuable.
(i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.
This habit will help you stay focused and organized. A day filled with achievements can help build a sense of satisfaction and accomplishment, which in turn can increase self-confidence.
(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as "body control."
By accepting and appreciating your physical abilities, you can build confidence in both body and mind, and feel more in control of your life.
(iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.
Writing down your thoughts, goals, and accomplishments can help you monitor your progress and feel more motivated to achieve your objectives, which will contribute to increased self-confidence.
(iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.
Being aware of your personal qualities and achievements will help you |
+| List five different ways to travel outdoors. | To travel outdoors, here are five different ways:
1. Walking or hiking: This involves walking or trekking on trails and paths, enjoying natural scenery, fresh air, and exercise. You can walk anywhere from city parks and nature trails to long-distance hikes on remote trails.
2. Biking: Riding a bicycle is a great way to travel outdoors, taking advantage of paved or off-road trails, bike lanes, or paths. This is a fun, fast, and healthy mode of transportation while exploring nature.
3. Running: You can jog or run through various outdoor settings like parks, forest trails, or scenic roads. Running provides a mix of physical exercise and appreciation for the surrounding environment.
4. Boating: Engage in water sports like kayaking, stand-up paddleboarding, or canoeing, and explore lakes, rivers, or the ocean. You can also fish, swim, or just relax and enjoy the water vistas.
5. Camping: Set up a tent or stay in a camper van and enjoy outdoor travel. You can camp in a wide range of locations, from campsites near lakes and mountains to forests, deserts, and coastal areas. Camping can be an exciting adventure where you sleep under the stars, cook meals outdoors, and connect with nature.
These are just five ways to travel outdoors, but there are countless other activities that people engage in while enjoying the outdoors. Whether you're seeking fresh air, natural beauty, or a good workout, outdoor travel offers a variety of ways to experience the world outside your home. | Please give details.
Solution:
Sure, here are five different ways to travel outdoors:
1. Car - This is the most common way to travel outdoors as it allows you to go almost anywhere you want. Just make sure to have a valid license and enough fuel for the trip.
2. Bike - Riding a bike is a great way to exercise and get some fresh air. You can go for a leisurely ride or take on more challenging trails.
3. Boat - Boating is a great way to explore waterways and coastal areas. You can rent a kayak, paddleboard, or sailboat to get out on the water.
4. Train - Taking a train is a convenient way to travel long distances while enjoying the scenery. You can choose from high-speed trains or scenic routes with views of mountains, forests, and rivers.
5. Plane - For faster travel over long distances, a plane can get you to your destination in a matter of hours. Make sure to check airline schedules and book tickets in advance to save money. |
+
+
+
+### Summarization
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 中新网3月4日电国台办发言人朱凤莲3月4日表示,由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”朱凤莲说,疫情发生以来,大陆方面一方面全力照顾在大陆台胞的生活和疫情防控需要,另一方面充分考虑滞鄂台胞的实际需求和回家心愿,积极安排东航于2月3日运送首批247名台胞返回台湾,并于2月5日和此后多次提出尽快运送其他提出返乡要求台胞的合理安排,包括提出由两岸航空公司共同执飞临时航班的运送安排,以满足滞鄂台胞急切回家的愿望。但民进党当局却一而再、再而三变换借口,不断设置障碍,一再拖延阻挠。“2月15日,我办发言人已详细披露大陆方面持续做出运送台胞安排和为实现运送不懈努力的全过程和细节,具体情况清清楚楚,事实真相一目了然。”朱凤莲指出,民进党当局不断以各种借口阻止东航后续运送,有目共睹。苏贞昌自己就曾公开说过,不能让在湖北的台胞回去,是因为岛内防疫安置能量不足。更有甚者,民进党当局竟然将期待返乡就业、学习团聚等1148名台胞列入所谓“注记管制名单”,全面封堵了滞鄂台胞回家之路。事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。朱凤莲强调,苏贞昌企图以自相矛盾的谎言转移视线、推卸责任,未免低估了广大台胞的智商。“我们奉劝他要有起码的道德底线,停止信口雌黄,停止造谣生事。我们质问他,敢不敢讲立即同意这1148名台胞返乡?”(原题为《国台办:奉劝苏贞昌停止造谣说谎》)\n 请你简要总结概括以上段落材料,注意字数不超过80词。 | 民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。 | 本文是一篇新闻报道。主要阐述了由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家的事实,苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”国台办发言人朱凤莲指出,事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。 |
+| All of the affected beach huts are brick-built chalets in Jubilee Parade in Lowestoft.\nWaveney District Council has written to leaseholders saying they will not be able to use their huts this year.\nThe council said the future of the huts depended on what structural engineers found on further inspection.\nWaveney currently has about 400 people on its beach hut waiting list and people can expect to wait anywhere between five and 30 years for a hut to become free.\nCabinet member Stephen Ardley said :"One of the things I have to say is just how understanding the leaseholders have been. There has been sadness rather than annoyance."\nPlease briefly summarize the above material within 20 words.
Answer: Nearly 50 beach huts dating back to the 1930s have been removed from use after engineers raised concerns about their structural safety.| Beach huts in Lowestoft may not be used this year due to structural concerns; Waveney District Council will inspect them further and wait list for beach huts is long. | Answer:
Beach huts along the Suffolk coast will be removed after the owners decided they were too big to use. |
+
+
+
+### Sentiment Analysis
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请判断这个句子的情绪是否正面:“我讨厌巧克力”。 | 这句话表达的是一种讨厌和厌恶的情绪。 | 我讨厌巧克力
负面情绪。 |
+| 请判断这个句子的情绪是否正面:“我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | “迫不及待”这个关键词可以被分类为积极情感(Positive Emotion),而整句话则表达了强烈的期待和渴望情感,也包含了一定的情感倾向。 |
+| Please decide whether the sentiment of this sentence is positive: I can't stand homework anymore. | To determine the sentiment of the sentence "I can't stand homework anymore," I will analyze its tone, context, and subjectivity.
1. Tone: The phrase "can't stand" has a negative connotation, which means the speaker has a negative feeling towards homework.
2. Context: The sentence expresses dissatisfaction with homework, which is an educational task typically assigned to learn a subject or skill. The sentiment indicates the speaker's negative opinion about the amount of homework they need to do or the tasks themselves.
3. Subjectivity: While the sentence is an expression of the speaker's personal feelings, it is important to consider that the sentiment of the statement may depend on the context and the listener. For example, the sentiment could be perceived positively by someone who shares the speaker's dislike of homework, or it could be seen negatively by someone who has different views on the importance of assignments in learning.
Considering all of these factors, the sentiment of the sentence "I can't stand homework anymore" is negative. | pick from the following.
- negative.
- positive.
I think the answer is
negative.
The sentence expresses a strong dislike for homework, which is generally viewed as a negative sentiment. |
+| Please decide whether the sentiment of this sentence is positive: my cat is very cute. | The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.
Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.
Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | OPTIONS: - negative - positive
positive |
+
+
+
+### Logical Reasoning
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 在一个笼子里有鸡和兔子,它们的总数是 35,总腿数是 94。请推断出鸡和兔子的数量。 | 分析假设法,假设鸡有x只,则兔子有(35-x)只。总腿数:(4x+2(35-x))只。4x+2(35-x)=94解得x=7 | 设鸡有 x 只,兔子有 y 只,则根据题目中的条件,可以列出两个方程:
1. x + y = 35
2. 2x + 4y = 94
接下来,解这两个方程可以得到:
x = 18
y = 17
因此,鸡有 18 只,兔子有 17 只。 |
+| 有 8 个相同的球,其中一个比其他 7 个更重。现在有一台天平,只能使用两次。如何找到最重的那个球? | 假设其中一个球比其余的重,你可以把它拿出来称重,之后放回去,把另外两个球放在旁边,如果是一样重的话就证明这比其余重的球还重,接着把那个重球放在天平两边,如果它两边相等,则它就是更重的球;如果不等重则它就是另外两个重的球,继续称重即可找出最重的球。 | 可以将 8 个球分成两份,每份 4 个。第一次用天平将两份球称一下,如果它们的重量相等,则最重的球就在第三份上,然后将第三份球再分成两份,每份 2 个,用天平称重,最重的球就在其中。如果两个 4 份中的重量不一样,则最重的球在轻的一份中,将这部分的球分成两份,每份 2 个,再称重,最重的球就在其中。如果两个 4 份中的重量一样,则最重的球就在最轻的那一份中,再将这一份分成两份,每份 1 个,用天平称重,最重的球就在其中。 |
+| All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:
1. Determine the weight of 1 book.
2. Multiply the weight of 1 book by the number of books.
Step 1: Determine the weight of 1 book.
Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.
5kg ÷ 10 = 0.5kg
Step 2: Multiply the weight of 1 book by the number of books.
To find the weight of 2 books, we multiply the weight of 1 book by 2.
0.5kg × 2 = 1kg
So, the weight of 2 books is 1kg. | Answer: If 10 books weigh 5kg, then a single books weighs 5 / 10 = 0.5kg. If 2 books weigh 0.5kg each, then in total they weigh 2 * 0.5 = 1kg. The answer is 1. \[Question\]Ralph is going to practice playing tennis with a tennis ball machine that shoots out tennis balls for Ralph to hit. He loads up the machine with 175 tennis balls to start with. Out of the first 100 balls, he manages to hit 2/5 of them. Of the next 75 tennis balls, he manages to hit 1/3 of them. Out of all the tennis balls, how many did Ralph not hit? \[Answer\]Out of the first 100 balls, Ralph was able to hit 2 / 5 of them and not able to hit 3 / 5 of them, 3 / 5 x 100 = 60 tennis balls Ralph didn't hit. Out of the next 75 balls, Ralph was able to hit 1 / 3 of them and not able to hit 2 / 3 of them, 2 / 3 x 75 = 50 tennis balls that Ralph didn't hit. Combined, Ralph was not able to hit 60 + 50 = 110 tennis balls Ralph didn't hit. The answer is 110. \[Question\]Tom's rabbit can run at 25 miles per hour. His dog can run at 40 miles per hour. The cat can run at 20 miles per hour. The bird can fly at 30 miles per hour. The tank can travel 5 miles per hour. The turtle can crawl 0.02 miles per hour. What is the average speed of all the creatures mentioned? \[Answer\]
25 + 40 + 20 + 30 + 5 + .02 = 120. Average speed = total distance / total time = 120 / 6 = 20 miles per hour. The answer is 20. |
+| John is Bob's father. Mike is Bob's grandfather. What is the relationship between John and Mike? | John and Mike are related to Bob through the father-son relationship. John is Bob's father, which means he is Bob's biological parent. Mike is Bob's grandfather, which means he is Bob's father's father. Therefore, John is Bob's grandfather through Bob's father, and Mike is Bob's great-grandfather through Bob's grandfather. | - brother
- father
- grandfather
- not enough information
father |
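+
+As a quick arithmetic check on the chicken-and-rabbit item above (offered for reference only, not part of either sample answer): with x chickens and y rabbits, x + y = 35 and 2x + 4y = 94 give x = 23 and y = 12, the only pair satisfying both constraints.
+
+```python
+# Brute-force check of the heads/legs system: x + y = 35, 2x + 4y = 94.
+for x in range(36):
+    y = 35 - x
+    if 2 * x + 4 * y == 94:
+        print(x, y)  # prints: 23 12
+```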
+
+
+
+### Information Extraction
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取上面句子中提到的艺术家的名字。 | 解题思路:文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取这句话的其中一个名词,是“文森特·梵高” | 文森特·梵高
根据提供的信息,上面的句子提到的艺术家的名字是文森特·梵高(Vincent van Gogh)。 |
+| The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | The answer is:
October 20th to October 23rd, 2021. |
+| Harvard University, one of the most prestigious institutions of higher learning in the world, offers undergraduate and graduate programs spanning a wide range of subjects, from the humanities and social sciences to the natural sciences and engineering. Extract the name of the university mentioned in the above sentence. | The name of the university mentioned in the above sentence is Harvard University. | A:
Harvard University |
+
+
+
+### Role-play
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 我需要一位个人理财顾问。你将对我的资产进行审核,并为我制定出优化的投资组合。你需要考虑我的收入、支出和风险偏好。要求你尽可能多地学习我的金融情况,并针对我的具体需求为我量身定制投资方案。我的第一个问题是,我应该在哪些领域进行投资? | 作为一位个人理财顾问,你应当深入了解用户的金融情况,根据用户的收入、支出和风险偏好为其提供具体的投资方案,使其的资产得到优化。用户在回答问题时,应该围绕问题进行,并且应尽可能多地收集用户的信息,以便为你的服务提供帮助。 | 请提供你的建议。
作为一名个人理财顾问,我需要了解你的个人情况、财务目标和风险承受能力,才能够为你制定出优化的投资组合。以下是一些可以考虑的领域和投资建议:
1. 股票市场:股票市场是一个高风险高回报的投资领域,但是也需要投资者具备较强的风险承受能力。如果你风险承受能力较强,可以考虑投资一些有潜力的行业股票,例如科技、医疗、能源等。
2. 债券市场:债券市场相对较稳定,投资回报相对较低,但是也具有一定的保值价值。如果你风险偏好较低,可以考虑投资一些信用评级较高的债券。
3. 房地产市场:房地产市场是一个相对较长期的投资领域,但是也具有一定的风险。如果你有大量闲置资金可以考虑投资,但是需要评估当地的房地产市场。
4. 投资基金:投资基金可以有效地分散风险,适合于想要在相对低风险的前提下,获得一些投资收益的投资者。
根据你的具体情况和需求,我建议你可以考虑投资股票市场和投资基金。我将会为你制定出符合你具体需求的投资方案。 |
+| I would like you to be a recipe creator. You need to create a recipe from the ingredients I give you. Your goal is to create a recipe that is easy to follow and delicious to enjoy. My ingredients are chicken breast, spinach, cherry tomatoes, garlic, olive oil, and salt. Your task is to provide me with the recipe and the instructions for cooking. Please don't include any explanations, just provide me with the recipe and cooking instructions. | Sure, I can help you with that! Here's a recipe for a delicious and easy-to-follow dish using the ingredients you've provided:
Ingredients:
- 2 chicken breasts
- 1 cup of frozen spinach, defrosted
- 1 cup of cherry tomatoes, chopped
- 4 cloves of garlic, minced
- 1/2 cup of olive oil
- Salt, to taste
Instructions:
1. Preheat your oven to 400°F (200°C).
2. Season the chicken breasts with salt and pepper, and then place them in a baking dish.
3. Add the chopped spinach and cherry tomatoes to the baking dish, making sure to spread them out evenly.
4. In a small bowl, combine the minced garlic and 1/4 cup of olive oil. Whisk together until well combined.
5. Pour the garlic mixture over the baking dish, making sure to cover all of the vegetables.
6. Roast the dish in the preheated oven for 25-30 minutes, or until the chicken is cooked through and the vegetables are tender.
7. Remove the dish from the oven and let it cool for a few minutes before serving.
That's it! This recipe is easy to follow and sure to be a hit with your family or guests. Enjoy! | Please begin your output with "recipe:". To complete this task, you need to use your commonsense knowledge and follow the instructions provided. Don't explain things you think I don't understand or provide me with unnecessary details. Just focus on providing the recipe and cooking instructions.
Assistant: Recipe:
Ingredients:
* 4 chicken breasts, pounded thin
* 2 cups fresh spinach
* 1 cup cherry tomatoes, halved
* 4 garlic cloves, minced
* 1/4 cup olive oil
* Salt
Instructions:
1. Preheat the oven to 375°F (190°C).
2. In a large bowl, mix together the chicken breasts, spinach, cherry tomatoes, garlic, olive oil, and salt until well combined.
3. Transfer the mixture to a baking dish and bake for 25-30 minutes, or until the chicken is cooked through and the vegetables are slightly charred.
4. Remove the dish from the oven and let it cool for a few minutes before serving. Enjoy! |
+
+
+
+### Classification
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | 根据新闻标题《“燕赤霞”为何手里拿金刚经?》可以判断新闻所属的分类是文化,娱乐。因此,选择的答案是文化。 | 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?
文化。 |
+| 分析以下文本中包含的情绪,并从下列选项中选择最合适的类别:“无”、“愤怒”、“厌恶”、“恐惧”、“高兴”、“悲伤”、“惊讶”。文本:“天天酸了吧唧的有劲吗?” | “我不喜欢酸的东西”这句话是负面情绪的表现,包含了“讨厌”的情绪。 | 情感分析是一种将人类语言转化成计算机可处理数据的技术,可以帮助分析文本中的情绪状态。根据上述文本的语境,“天天酸了吧唧的有劲吗?”表现出一种负面的情绪,因此最合适的分类是“厌恶”。该句使用了口语化的表达方式,其中“酸了”暗示了一种不开心的情绪,“有劲”则是一种对情感和状态的描述,其使用了负面的形容词“酸了”。总体来说,这句话表现出的是一种负面的情绪,即厌恶。 |
+| Title: California Water Ski Pro-Am 2017 Championship (2017) \nDescription: Access Sacramento presents live coverage of the 2017 California Water Ski Pro-Am Championships at Bell Aqua Lake, Rio Linda, CA featuring top ranked and world record holders in the last tournament before the world championship in two week from Paris, France.\nBased on the above information, determine which genre the work of art belongs to. You can only choose one from "sport", "horror", "drama", "history", "romance", "biography", "science fiction", "comedy", "animation", "documentary", "music" and "news".
Answer: sport | The work of art is a sports documentary. | The work of art belongs to the genre "sport". |
+
+
+
+### Rewriting
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | 我喜欢看电影和读书。
I like to watch movies and read books |
+| 将以下句子改为被动语态。 我们要在明天之前完成这个项目。 | 我们要在明天之前完成这个项目。 | 这个项目必须在明天之前被完成。 |
+| Please translate the following sentence to English. "我生病了,要去医院看病". | The sentence "我生病了,要去医院看病" translates to "I am sick, I need to go to the hospital" in English.
Here's how I arrived at the translation:
- "我" means "I" in English.
- "生病了" means "am sick" in English.
- "要去" means "need to go" in English.
- "医院" means "hospital" in English.
- "看病" means "to the hospital" in English.
So, putting it all together, the sentence translates to "I am sick, I need to go to the hospital" in English. | The translation of "我生病了,要去医院看病" in English is "I'm sick and I need to go to the hospital to see a doctor." |
+| Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | It is grammatically correct to say "My friend and I went to the store." Here's the corrected sentence: "My friend and I went to the store." |
+
+
+
+## Conclusion
+In general, the Colossal-LLaMA-2-7b-base model not only improves its understanding of English but also exhibits significant gains in Chinese comprehension. It commands a broad spectrum of general knowledge across fields such as food, sports, technology, literature, games, and more. Among text generation tasks, the model excels at writing; however, its ability to generate specific formats such as code, emails, and tables still needs improvement, because relevant data was scarce during our training phase. Compared with the Qwen-7b-base model, the Colossal-LLaMA-2-7b-base model answers most English questions and some Chinese questions better, as the examples above demonstrate.
+
+At present, the Colossal-LLaMA-2-7b-base model already exhibits capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting, and we will continue to improve these capabilities in future releases.
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example
new file mode 100644
index 000000000000..82948648cbc9
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/hostfile.example
@@ -0,0 +1,2 @@
+hostname1
+hostname2
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
new file mode 100644
index 000000000000..a519232f6e38
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Prepare dataset for continual pre-training
+"""
+
+import argparse
+import json
+import math
+import os
+import time
+from multiprocessing import cpu_count
+
+from datasets import dataset_dict, load_dataset
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+
+from colossalai.logging import get_dist_logger
+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+ supervised_tokenize,
+ ClosedToConstantLengthSplicedDataset,
+)
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_input_dirs",
+ type=str,
+ required=True,
+ default=None,
+ help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
+ )
+ parser.add_argument(
+ "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
+ )
+ parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
+ parser.add_argument(
+ "--data_jsonl_output_dir",
+ type=str,
+ default="jsonl_output",
+ help="Output directory of spliced dataset with jsonl format",
+ )
+ parser.add_argument(
+ "--data_arrow_output_dir",
+ type=str,
+ default="arrow_output",
+ help="Output directory of spliced dataset with arrow format",
+ )
+ parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+ parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+ args = parser.parse_args()
+
+ if args.num_spliced_dataset_bins >= 100000:
+ raise ValueError("Too many spliced divisions, must be smaller than 100000")
+
+    assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
+    assert not os.path.exists(
+        args.data_jsonl_output_dir
+    ), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
+    assert not os.path.exists(
+        args.data_arrow_output_dir
+    ), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
+ os.makedirs(args.data_jsonl_output_dir)
+ os.makedirs(args.data_arrow_output_dir)
+
+    # Collect all input dataset paths.
+ input_data_paths = []
+ input_data_dirs = args.data_input_dirs.split(",")
+ for ds_dir in input_data_dirs:
+ ds_dir = os.path.abspath(ds_dir)
+        assert os.path.exists(ds_dir), f"Data dir {ds_dir} not found"
+ ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
+ ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
+ input_data_paths.extend(ds_paths)
+
+    # Prepare the data splits.
+ train_splits = []
+ split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
+ for i in range(0, 100, split_interval):
+ start = i
+ end = i + split_interval
+ if end > 100:
+ end = 100
+ train_splits.append(f"train[{start}%:{end}%]")
+
+    # Prepare the tokenizer.
+ tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.unk_token
+
+ list_dataset = load_dataset(
+ path="json",
+ data_files=input_data_paths,
+ cache_dir=os.path.join(args.data_cache_dir, "raw"),
+ keep_in_memory=False,
+ split=train_splits,
+ num_proc=cpu_count(),
+ )
+ for index, dataset in enumerate(list_dataset):
+ assert isinstance(dataset, dataset_dict.Dataset)
+ logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
+ dataset = dataset.map(
+ function=supervised_tokenize,
+ fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
+ keep_in_memory=False,
+ num_proc=min(len(dataset), cpu_count()),
+ )
+ dataset = dataset.remove_columns(column_names=["source", "target", "category"])
+ dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
+ dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"])
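+        # Splice tokenized samples into sequences of roughly constant length (up to max_length tokens).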
+ spliced_dataset = ClosedToConstantLengthSplicedDataset(
+ dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False
+ )
+ # Save each jsonl spliced dataset.
+ output_index = "0" * (5 - len(str(index))) + str(index)
+ output_name = f"part-{output_index}"
+ output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
+ st = time.time()
+ with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
+ spliced_count = 0
+ for spliced_data_point in spliced_dataset:
+ if spliced_count % 500 == 0:
+ logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
+ spliced_count += 1
+ fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
+ logger.info(
+ f"Current file {fp_writer.name}; "
+ f"Data size: {len(spliced_dataset)}; "
+ f"Spliced data size: {spliced_dataset.current_size}; "
+ f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; "
+ f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
+ )
+
+ # Save each arrow spliced dataset
+ output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
+ logger.info(f"Start to save {output_arrow_path}")
+ spliced_dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_path],
+ cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt
new file mode 100644
index 000000000000..d8afee768c02
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/requirements.txt
@@ -0,0 +1,15 @@
+torch<2.0.0, >=1.12.1
+packaging==23.1
+colossalai==0.3.2
+autoflake==2.2.1
+black==23.9.1
+transformers
+tensorboard==2.14.0
+six==1.16.0
+datasets
+ninja==1.11.1
+flash-attn>=2.0.0,<=2.0.5
+tqdm
+sentencepiece==0.1.99
+protobuf<=3.20.0
+
diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh
new file mode 100644
index 000000000000..276d9ce99d42
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train.example.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+# NCCL IB environment variables
+export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
+export NCCL_IB_DISABLE=0
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_IB_GID_INDEX=3
+export NCCL_IB_TIMEOUT=23
+export NCCL_IB_RETRY_CNT=7
+export OMP_NUM_THREADS=8
+
+PROJECT_NAME=""
+PARENT_SAVE_DIR=""
+PARENT_TENSORBOARD_DIR=""
+PARENT_CONFIG_FILE=""
+PRETRAINED_MODEL_PATH=""
+
+declare -a dataset=(
+ "PATH TO THE DATASET"
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
+ --pretrained $PRETRAINED_MODEL_PATH \
+    --dataset "${dataset[@]}" \
+ --plugin "zero2" \
+ --save_interval 400 \
+ --save_dir $SAVE_DIR \
+ --tensorboard_dir $TENSORBOARD_DIR \
+ --config_file $CONFIG_FILE \
+ --num_epochs 1 \
+ --micro_batch_size 8 \
+ --lr 1e-4 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 100 \
+ --use_grad_checkpoint \
+ --use_flash_attn \
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
new file mode 100644
index 000000000000..41b4ef031b46
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+"""
+
+import json
+import argparse
+import os
+import resource
+from contextlib import nullcontext
+from tqdm import tqdm
+
+import torch
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import (
+ GeminiPlugin,
+ LowLevelZeroPlugin,
+ HybridParallelPlugin,
+)
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+from colossal_llama2.dataset.loader import (
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+)
+
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f"{numel / B:.2f} B"
+ elif numel >= M:
+ return f"{numel / M:.2f} M"
+ elif numel >= K:
+ return f"{numel / K:.2f} K"
+ else:
+ return f"{numel}"
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def main() -> None:
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pretrained",
+ type=str,
+ default=None,
+ help="Address of the pre-trained modeling",
+ )
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
+ parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
+ parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
+ parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
+ parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="fp16",
+ choices=["fp16", "bf16"],
+ help="Mixed precision",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument(
+ "--use_grad_checkpoint",
+ action="store_true",
+ default=False,
+ help="Use gradient checkpointing",
+ )
+ parser.add_argument(
+ "--use_flash_attn",
+ action="store_true",
+ default=False,
+ help="Use flash-attention",
+ )
+ parser.add_argument(
+ "--freeze_non_embeds_params",
+ action="store_true",
+ default=False,
+ help="Freeze non embeddings parameters",
+ )
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--zero", type=int, default=1)
+ args = parser.parse_args()
+
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if coordinator.is_master():
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=args.zero,
+ max_norm=args.grad_clip,
+ precision=args.mixed_precision,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
+
+ # ======================================================
+ # Initialize Tokenizer, Dataset, Collator and Dataloader
+ # ======================================================
+ tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
+ tokenizer.pad_token = tokenizer.unk_token
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
+ coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
+ coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
+
+ coordinator.print_on_master(f"Load dataset: {args.dataset}")
+
+ dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+ dataloader = setup_distributed_dataloader(
+ dataset=dataset,
+ batch_size=args.micro_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ )
+ coordinator.print_on_master(
+ f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ init_ctx = (
+ LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ )
+ with init_ctx:
+ model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+ # Freeze part of parameters.
+ if args.freeze_non_embeds_params:
+ freeze_non_embeds_parameters(model=model)
+
+ if args.use_grad_checkpoint:
+ model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ if args.use_flash_attn:
+ replace_with_flash_attention(model=model)
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+
+ model_numel = get_model_numel(model)
+ coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+
+ optimizer = HybridAdam(
+ model_params=filter(lambda p: p.requires_grad, model.parameters())
+ if args.freeze_non_embeds_params
+ else model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optimizer,
+ total_steps=args.num_epochs * len(dataloader),
+ warmup_steps=args.warmup_steps
+ if args.warmup_steps is not None
+ else int(args.num_epochs * len(dataloader) * 0.025),
+ eta_min=0.1 * args.lr,
+ )
+
+ # Flash attention will be disabled because it does NOT support fp32.
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dataloader=dataloader,
+ )
+
+ torch.set_default_dtype(torch.float)
+
+ if args.load_checkpoint is None:
+ coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
+ booster.load_model(model, args.pretrained, strict=False)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ start_step = 0
+ sampler_start_idx = 0
+ if args.load_checkpoint is not None:
+ if "modeling" in args.load_checkpoint:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
+ booster.load_model(model, args.load_checkpoint)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.load_checkpoint,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ )
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ num_steps_per_epoch = len(dataloader)
+ # If resume training, set the sampler start index to the correct value
+ assert isinstance(dataloader.sampler, StatefulDistributedSampler)
+ dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ for epoch in range(start_epoch, args.num_epochs):
+ dataloader.sampler.set_epoch(epoch=epoch)
+ with tqdm(
+ iterable=enumerate(dataloader, start=start_step),
+ desc=f"Epoch {epoch}",
+ disable=not coordinator.is_master(),
+ total=num_steps_per_epoch,
+ initial=start_step,
+ ) as pbar:
+ for step, batch in pbar:
+ batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+
+ batch_output = model(**batch)
+
+ loss = batch_output.loss
+
+ booster.backward(loss=loss, optimizer=optimizer)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ all_reduce_mean(tensor=loss)
+ pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
+ if coordinator.is_master():
+ global_step = epoch * num_steps_per_epoch + step
+ writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
+ writer.add_scalar(
+ tag="Learning Rate",
+ scalar_value=lr_scheduler.get_last_lr()[0],
+ global_step=global_step,
+ )
+                # Save the checkpoint periodically and at the end of the epoch.
+                if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.micro_batch_size,
+ coordinator=coordinator,
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )
+
+ # Delete CUDA cache.
+ # del batch, batch_labels, batch_output, loss
+ torch.cuda.empty_cache()
+
+        # Subsequent epochs are not resumed mid-epoch, so reset the sampler start index and the start step.
+ dataloader.sampler.set_start_index(start_index=0)
+ start_step = 0
+
+ # Final save.
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
+ )
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt
new file mode 100644
index 000000000000..8a9ecc2ea99d
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/version.txt
@@ -0,0 +1 @@
+0.0.1
\ No newline at end of file
diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md
new file mode 100644
index 000000000000..3f645fe7892c
--- /dev/null
+++ b/applications/ColossalEval/README.md
@@ -0,0 +1,560 @@
+# ColossalEval
+
+## Table of Contents
+
+- [Overview](#overview)
+- [Leaderboard](#leaderboard)
+- [Install](#install)
+- [Evaluation Process](#evaluation-process)
+ - [Inference](#inference)
+ - [Dataset Preparation](#dataset-preparation)
+ - [Configuration](#configuration)
+ - [How to Use](#how-to-use)
+ - [Evaluation](#evaluation)
+ - [Dataset Evaluation](#dataset-evaluation)
+ - [Configuration](#dataset-evaluation)
+ - [How to Use](#dataset-evaluation)
+ - [GPT Evaluation](#gpt-evaluation)
+ - [Configuration](#gpt-evaluation)
+ - [How to Use](#gpt-evaluation)
+- [More Details](#more-details)
+ - [Inference Details](#inference-details)
+ - [Evaluation Details](#evaluation-details)
+ - [Metrics](#metrics)
+  - [Examples](#examples)
+ - [Dataset Evaluation Example](#dataset-evaluation-example)
+ - [GPT Evaluation Example](#gpt-evaluation-example)
+- [To Do](#to-do)
+- [FAQ](#faq)
+ - [How to Add a New Metric?](#how-to-add-a-new-metric)
+ - [How to Add a New Dataset?](#how-to-add-a-new-dataset)
+ - [How to Add a New Model?](#how-to-add-a-new-model)
+- [Citations](#citations)
+
+## Overview
+[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project that provides a uniform pipeline for evaluating language models on public datasets or your own dataset, using both classic metrics and GPT-assisted evaluation. More details can be found in the following sections.
+
+## Leaderboard
+
+We conducted a comprehensive evaluation on 4 datasets and compared our Colossal-Llama-2-7b-base model with various other models.
+
+- We use 5-shot for MMLU and calculate scores based on the logits of the first predicted token.
+- We use 5-shot for CMMLU and calculate scores based on the logits of the first predicted token.
+- We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combined metric of exact match and the logits of the first predicted token. If either the exact match or the first-token logits is correct, the model gets the score.
+- We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of the first predicted token.
+- The generation config for all datasets is greedy search.
+- We also provide CEval scores from its latest leaderboard or from the official repository of the model.
+
+More details about metrics can be found in [Metrics](#metrics).
+
+| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
+| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :----------------------------: |
+| | - | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
+| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| InternLM-20B | - | 2.3T | | 60.96 (62.05) | 59.08 (-) | 57.96 | 61.92 | - |
+| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| Qwen-7B | - | 2.4T | | 58.33 (58.20) | 62.54 (62.20) | 64.34 | 74.05 | 63.50 |
+| | | | | | | | | |
+| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.20 |
+
+> The score in parentheses corresponds to the scores in the official repository of the model.
+>
+> We use zero-shot for ChatGLM models.
+>
+> To evaluate Qwen-7B on the MMLU dataset, the prompt would be "xxx Answer:" (with no space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Both the original and updated versions of Qwen-7B tend to be much more deterministic than other models. For example, the logits over " A" can be `-inf`, and the softmax would be exactly `0` (see the short sketch below).
+>
+> For other models and other datasets, we calculate logits over "A", "B", "C" and "D".
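+
+As a small illustration of the note above, a `-inf` logit collapses to an exact zero probability after softmax (a standalone snippet with made-up values, not part of the pipeline):
+
+```python
+import torch
+
+# Hypothetical first-token logits over the options " A", " B", " C", " D".
+logits = torch.tensor([float("-inf"), 2.0, 1.0, 0.5])
+probs = torch.softmax(logits, dim=-1)
+print(probs[0].item())  # exactly 0.0, since exp(-inf) == 0
+```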
+
+Our model achieves much better scores than all other Llama-1 or Llama-2 based models and also stands out among popular open-source LLMs.
+
+## Install
+You need to install `ColossalEval` before using it; the installed package is named `colossal_eval`.
+```bash
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI/applications/ColossalEval
+pip install .
+```
+If you want to add customized datasets or models, use `pip install -e .` instead to ensure that any changes you make to the source code will immediately affect the package you install.
+
+## Evaluation Process
+The evaluation process involves two steps, `inference` and `evaluation`, and you need to set the config for each step.
+
+### Inference
+
+The inference process consists of two parts.
+1. Preprocess and convert the original dataset.
+2. Config your tokenizer and model arguments to perform zero-shot or few-shot prompting.
+
+#### Dataset Preparation
+
+In this step, the original dataset (either in `csv` or `jsonl` format) will be loaded and converted into a `dict`. During conversion, we carefully parse each subcategory and assign it specific inference arguments.
+
+Inference arguments are stored in a `dict`. The following is an example.
+
+```python
+inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32
+}
+```
+The `inference_kwargs` currently contains 5 fields:
+
+- `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated
+- `all_classes` (Optional[list], compulsory): All available options for a single-choice question, specified as a list; otherwise `None`.
+- `language` (str, compulsory): The language for the subcategory.
+- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculating perplexity when you want to evaluate a model with an extended context length.
+- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.
+
+For example, for the MMLU dataset, each subcategory consists of single-choice questions with options A, B, C and D by default, so we can assign the value `["A", "B", "C", "D"]` to the key `all_classes`. For the C-Eval dataset, target answers aren't provided in the test split, so `calculate_loss` should be set to False. However, other datasets such as GAOKAO-Bench contain different formats of questions and lack keys or metadata that reveal whether a question is single-choice or multi-choice. Before assigning inference arguments, we first parse the dataset to decide which type of question the subcategory contains and set the inference arguments accordingly, as sketched below.
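+
+A minimal sketch of this parse-then-assign step (the helper `parse_question_type` is hypothetical and shown only for illustration; each dataset class implements its own parsing):
+
+```python
+from copy import deepcopy
+
+default_inference_kwargs = {
+    "calculate_loss": True,
+    "all_classes": None,
+    "language": "Chinese",
+    "pretrain": False,
+    "max_new_tokens": 32,
+}
+
+
+def build_inference_kwargs(subcategory_samples: list) -> dict:
+    # Start from the defaults and only override what the subcategory requires.
+    inference_kwargs = deepcopy(default_inference_kwargs)
+    # parse_question_type is a hypothetical helper that inspects the samples.
+    question_type, options = parse_question_type(subcategory_samples)
+    if question_type == "single-choice":
+        inference_kwargs["all_classes"] = options  # e.g. ["A", "B", "C", "D"]
+    return inference_kwargs
+```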
+
+Other than `inference_kwargs`, `data` is a list containing questions of the same subcategory. The following is a converted dataset.
+
+```json
+{
+ "dev": {
+ "category 1": {"data": [], "inference_kwargs": {}},
+ "category 2": {"data": [], "inference_kwargs": {}}
+ },
+ "test": {
+ "category 1": {"data": [], "inference_kwargs": {}},
+ "category 2": {"data": [], "inference_kwargs": {}}
+ }
+}
+```
+
+A data sample basically follows the Alpaca format. It should contain the following keys:
+
+* `dataset` (str, compulsory): The name of the dataset.
+* `split` (str, compulsory): The split of the instruction.
+* `category` (str, compulsory): The category of the instruction.
+* `instruction` (str, compulsory): The instruction for the LLM.
+* `input` (str, optional): The additional context of the instruction.
+* `output` (str, optional): The model output of the instruction.
+* `target` (str, optional): The target answer for the instruction.
+
+Example:
+
+```json
+{
+ "dev": {
+ "Abstract Algebra": [
+ {
+ "dataset": "mmlu",
+ "split": "dev",
+ "category": "Abstract Algebra",
+ "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.",
+ "input": "Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer: ",
+ "output": "",
+ "target": "B"
+ },
+ ]
+ },
+ "test": {
+ "Abstract Algebra": [
+ {
+ "dataset": "mmlu",
+ "split": "test",
+ "category": "Abstract Algebra",
+ "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.",
+ "input": "Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.\nA. 0\nB. 4\nC. 2\nD. 6\nAnswer: ",
+ "output": "",
+ "target": "B"
+ },
+ ]
+ }
+}
+```
+
+#### Configuration
+In this step, you will configure your tokenizer and model arguments to infer on the given datasets.
+
+A config file consists of two parts.
+1. Model config. In model config, you need to specify the model name, model path, model class, tokenizer arguments and model arguments. For the model class, we currently support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel`, and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model requires `trust_remote_code=True`, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
+2. Dataset config. In dataset config, you need to specify the dataset name, path and dataset class. Currently, we support zero-shot on the MMLU, CMMLU, AGIEval, GAOKAO-Bench and LongBench datasets, and few-shot on the MMLU, CMMLU and AGIEval datasets. If you want to enable few-shot, set `few_shot` to true. You can check all dataset classes in `colossal_eval/dataset/__init__.py`.
+
+Once you have all configs ready, the program will run inference on all the given datasets with all the given models.
+
+An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be:
+```json
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 2048,
+ "tokenizer_path": "path to tokenizer",
+ "tokenizer_kwargs": {
+ "use_fast": false,
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "dataset name",
+ "dataset_class": "CMMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset",
+ "save_path": "path to save converted dataset"
+ }
+ ]
+}
+```
+
+Currently, we support Hugging Face models. The `tokenizer_kwargs` are the arguments passed to `AutoTokenizer.from_pretrained()`. The `model_kwargs` are the arguments passed to `AutoModel.from_pretrained()` or `AutoModelForCausalLM.from_pretrained()`. Set `few_shot` to true if you want to enable few-shot prompting for the dataset. Set `debug` to true if you want to verify whether your prompt is correct.
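+
+Roughly speaking, these fields map onto the standard Hugging Face loaders as follows (a simplified sketch, not the exact `colossal_eval` loading code):
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Values taken from the "parameters" section of the example config above.
+tokenizer_kwargs = {"use_fast": False, "trust_remote_code": True}
+model_kwargs = {"trust_remote_code": True}
+
+tokenizer = AutoTokenizer.from_pretrained("path to tokenizer", **tokenizer_kwargs)
+model = AutoModelForCausalLM.from_pretrained("path to model", **model_kwargs)
+```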
+
+#### How to Use
+An example script is shown below. The `configs/dataset_evaluation/inference.py` script is the same in all the provided examples.
+
+```shell
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
+```
+
+You should specify the path to the config file in `config`. You can run the script without `load_dataset` if you have already saved the converted dataset; otherwise, set it to first load the original dataset and save the converted one. You should specify the path to save inference results in `inference_save_path`.
+
+### Evaluation
+
+In the evaluation process, you only need to configure your evaluation parameters. You can evaluate using either public datasets or GPT models. We will introduce the configuration for dataset evaluation and for GPT evaluation.
+
+#### Dataset Evaluation
+
+In dataset evaluation, we calculate different metrics over the given inference results and the public datasets.
+
+##### Configuration
+
+A config file for dataset evaluation consists of two parts.
+1. Model config. In model config, you need to specify model name. If you want to evaluate perplexity over a pretrain dataset and calculate per-byte-perplexity, you have to add your tokenizer config and model max length.
+2. Dataset config. In dataset config, you need to specify the evaluation metrics for the dataset.
+
+Once you have all configs ready, the program will run evaluation on the inference results for all given models and datasets.
+
+An example config can be:
+```json
+{
+ "model": [
+ {
+ "name": "model name"
+ }
+ ],
+ "dataset": [
+ {
+ "name": "dataset name",
+ "metrics": ["first_token_accuracy"]
+ }
+ ]
+}
+```
+
+The above config specifies that the program will evaluate the inference results using `first_token_accuracy` metric.
+
+##### How to Use
+
+An example script is shown below.
+
+```shell
+python eval_dataset.py \
+ --config "path to config file" \
+ --inference_results_path "path to inference results" \
+ --evaluation_results_save_path "path to save evaluation results"
+```
+
+You should specify the path to the config file in `config`, the path to inference results in `inference_results_path` and the path to save evaluation results in `evaluation_results_save_path`.
+
+#### GPT Evaluation
+
+In GPT evaluation, we provide a prompt template that fits different pre-defined metrics using Chain-of-Thought prompting. In the following sections, we only introduce how to evaluate model answers using GPTs. More details can be found in `colossal_eval/evaluate/GPT Evaluation.md`.
+
+##### Configuration
+
+The following is an example of an English config file. The configuration file controls how the pipeline evaluates the model. You need to specify the GPT evaluation metrics. You can find an example English config file in `configs/gpt_evaluation`.
+
+```json
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ }
+}
+```
+
+##### How to Use
+After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to compare the answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list` (details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1, and the program will perform evaluation using GPTs. The prompt files for battle and GPT evaluation can be found in `configs/gpt_evaluation/prompt`. `target file` is the path to the converted dataset you saved during inference.
+
+An example script is provided as follows:
+
+```shell
+python eval.py \
+ --config_file "path to the config file" \
+ --battle_prompt_file "path to the prompt file for battle" \
+ --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+ --target_file "path to the target answer file" \
+ --answer_file_list "path to the answer file" \
+ --model_name_list "the names of the model" \
+ --gpt_model "which GPT model to use for evaluation" \
+ --save_path "path to save results" \
+ --openai_key "your openai key" \
+```
+
+## More Details
+
+### Inference Details
+
+In the inference process, we perform generation, calculate the loss over target tokens, count the number of target tokens, and compute a softmax over the given options (for example, "A", "B", "C", and "D"), according to the inference arguments.
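+
+For instance, the softmax over options can be sketched as follows (a hypothetical helper assuming a Hugging Face causal LM; the real pipeline batches this and handles tokenizer quirks such as leading spaces):
+
+```python
+import torch
+
+
+def first_token_option_probs(model, tokenizer, prompt: str, options=("A", "B", "C", "D")):
+    # Score each option by the model's next-token logit for the option's
+    # first token, then apply a softmax over those logits only.
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    with torch.no_grad():
+        next_token_logits = model(input_ids).logits[0, -1]
+    option_ids = [tokenizer(o, add_special_tokens=False).input_ids[0] for o in options]
+    return torch.softmax(next_token_logits[option_ids], dim=-1)
+```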
+
+For tokenization, we adopt the strategy from [LongBench](https://github.com/THUDM/LongBench/blob/main/pred.py#L55) to preserve crucial instructions on the left and right sides and keep all target tokens.
+
+For labeling target tokens, we adopt the method from [FastChat](https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L137), but it doesn't always hold due to differences in tokenizer behavior. We plan to insert special tokens to correctly label the target tokens.
+
+For calculating loss, we return the per-sample loss rather than the per-batch loss that directly using `model(batch).loss` from Hugging Face would give.
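+
+A minimal sketch of the per-sample loss computation (assuming the standard causal-LM label convention where non-target positions are set to `-100`; an illustration, not the exact pipeline code):
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def per_sample_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    # logits: (batch, seq_len, vocab); labels: (batch, seq_len), -100 on non-target tokens.
+    shift_logits = logits[:, :-1, :]
+    shift_labels = labels[:, 1:]
+    # reduction="none" keeps one loss per token instead of averaging over the whole batch.
+    token_loss = F.cross_entropy(
+        shift_logits.reshape(-1, shift_logits.size(-1)),
+        shift_labels.reshape(-1),
+        ignore_index=-100,
+        reduction="none",
+    ).view(shift_labels.size())
+    target_counts = (shift_labels != -100).sum(dim=-1).clamp(min=1)
+    return token_loss.sum(dim=-1) / target_counts  # average loss per sample
+```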
+
+### Evaluation Details
+
+To make the config easier to set, you only need to specify all the metrics you want to use under the key `metrics`. However, the program will only apply a suitable subset of those metrics to each subcategory, since applying all metrics to all subcategories is obviously unsuitable. The suggested metrics for specific categories are defined in `colossal_eval/evaluate/dataset_evaluator/metrics.py`.
+
+#### Metrics
+
+- `combined_single_choice_accuracy`: A combination of `first_token_logit` and `single_choice_accuracy`. If either of them is correct, the model gets the score. It can be used in all datasets that contain single-choice questions.
+- `first_token_logit`: Calculate the score based on the softmax score over the given choices. If the argmax of the softmax is equal to the reference, the model gets the score. If there is `NaN` in the softmax score, the score is calculated using exact match instead. It can be used in all datasets that contain single-choice questions.
+- `single_choice_accuracy`: Calculate the score using exact match. It only takes the first uppercase letter such as A, B, C or D that is not surrounded by lowercase letters. If the uppercase letter is equal to the reference, the model gets the score. It can be used in all datasets that contain single-choice questions.
+- `multi_choice_accuracy`: Calculate the score on multi-choice questions. It takes the set of all uppercase letters such as A, B, C or D that are not surrounded by lowercase letters. If the prediction contains uppercase letters that are not in the reference, the model gets a score of 0. If the prediction contains an uppercase letter that is in the reference, the model gets a score of `1/len(reference)`. It is used in AGIEval and GAOKAO-Bench.
+- `math_equivalence`: Code from [hendrycks](https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py). Compute scores between the predicted math formula and the reference math formula. It is used in AGIEval and GAOKAO-Bench.
+- `f1_score`: Calculate the English f1 score between prediction and reference. It is used in LongBench.
+- `f1_zh_score`: Calculate the Chinese f1 score between prediction and reference. It is used in LongBench.
+- `rouge_score`: Calculate the English rouge score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `rouge_zh_score`: Calculate the Chinese rouge score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `retrieval_score`: Calculate the English retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `retrieval_zh_score`: Calculate the Chinese retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `classification_score`: Calculate the classification score between prediction and reference. It determines whether the output (a class) is equal to the reference. It is used in LongBench.
+- `code_sim_score`: Calculate the similarity score between prediction and reference. It is used in LongBench.
+- `count_score`: Calculate the count score between prediction and reference. It determines whether the output (the number of given passages) is equal to the reference. It is used in LongBench.
+- `perplexity`: Calculate perplexity. The formula is $ perplexity = \frac{1}{n} \sum_i e^{loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets. (See the code sketch after this list.)
+- `ppl_score`: Calculate the perplexity score. The formula is $ ppl\_score = \frac{1}{n} \sum_i e^{-loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets.
+- `ppl_score_over_choices`: Calculate the perplexity score over choices. The formula is $ ppl\_score\_over\_choices= \frac{1}{n} \sum_i e^{-loss\_over\_choices_i} $ where $n$ is the number of samples and $ loss\_over\_choices_i $ is the loss on the first predicted token for sample $ i $. It can be used in all datasets that contain single-choice questions.
+- `per_byte_perplexity`: Calculate the per-byte perplexity. The formula is $ \frac{1}{n} \sum_i e^{\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
+- `per_byte_ppl_score`: Calculate the per-byte perplexity score. The formula is $ \frac{1}{n} \sum_i e^{-\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
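+
+Given per-sample average losses, the `perplexity` and `ppl_score` formulas above translate directly into code (a small sketch under those definitions):
+
+```python
+import math
+
+
+def perplexity(sample_losses):
+    # sample_losses[i] is the average loss over the target tokens of sample i.
+    return sum(math.exp(loss) for loss in sample_losses) / len(sample_losses)
+
+
+def ppl_score(sample_losses):
+    return sum(math.exp(-loss) for loss in sample_losses) / len(sample_losses)
+```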
+
+We use `combined_single_choice_accuracy` and `first_token_logit` in the leaderboard.
+
+### Examples
+
+We provide two examples for you to explore our `colossal_eval` package.
+
+#### Dataset Evaluation Example
+
+This example is in the folder `examples/dataset_evaluation`.
+
+1. `cd examples/dataset_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters.
+3. Run `inference.sh` to get inference results.
+4. Fill in your evaluation config file in `config/evaluation/config.json`. Set the model and dataset parameters.
+5. Run `eval_dataset.sh` to get evaluation results.
+
+#### GPT Evaluation Example
+
+This example is in the folder `examples/gpt_evaluation`.
+
+1. `cd examples/gpt_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters. If you want to use the example dataset we provide, the dataset is `ColossalDataset`.
+3. Run `inference.sh` to get inference results.
+4. Fill in your evaluation config file in `config/evaluation/config.json`.
+5. Run `eval.sh` to get evaluation results.
+
+## FAQ
+
+### How to Add a New Metric?
+
+If you want to add a customized metric, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install.
+
+To add a new metric, you can follow the example of `multi_choice_accuracy` in line 339 of `colossal_eval/evaluate/dataset_evaluator/metric.py`. The method takes one data sample's prediction and reference as input and returns a score ranging from 0 to 1.
+
+A skeleton of code is the following.
+
+```python
+
+def CustomizedMetric(prediction: str, reference: str):
+ score = xxx
+ return score
+```
+
+Once you have successfully added your own metric, you should specify it both in `colossal_eval/evaluate/dataset_evaluator/metric.py` (suggest which subcategories the metric should be applied to) and in your evaluation config.
+
+### How to Add a New Dataset?
+
+If you want to add a customized dataset, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install.
+
+To add a new dataset, you can follow the example of `colossal_eval/dataset/mmlu.py`. You need to make sure that all questions in one subcategory share the same format. For example, all questions should have target answers, or all questions should be single-choice questions.
+
+A skeleton of code is the following.
+
+```python
+
+class CustomizedDataset(BaseDataset):
+ @staticmethod
+ def load():
+ # 1. Load and convert the original dataset format.
+ # 2. Assign inference arguments for each subcategory.
+ # 3. Return the converted dataset.
+ pass
+```
+
+Once you have successfully added your own dataset, you can specify your dataset class in your inference config.
+
+### How to Add a New Model?
+
+If you want to add customized models, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install.
+
+To add a new model, you can follow the example of `colossal_eval/models/huggingface.py`. You need to provide a way to load the model and tokenizer, calculate loss and generate.
+
+A skeleton of code is the following.
+
+```python
+
+class CustomizedModel(BaseModel):
+ def __init__(self):
+ super().__init__()
+ self._load_tokenizer()
+ self._load_model()
+
+ def _load_tokenizer():
+ pass
+
+ def _load_model():
+ pass
+
+ def _calculate_loss():
+ pass
+
+ def get_loss():
+ self._calculate_loss()
+
+ def inference(samples):
+ # 1. Load samples from the same subcategory.
+ # 2. Infer in a batch way according to inference arguments.
+ # 3. Return results.
+ batch_samples = xxx
+ self.get_loss(batch_samples)
+ self.generate(batch_samples)
+
+ return inference_results
+
+ def generate():
+ pass
+```
+
+Once you have successfully added your own model, you can specify your model class in your inference config.
+
+## To Do
+
+- [ ] Add visualization code for evaluation results on public dataset
+- [ ] Improve the way to label target tokens
+
+## Citations
+
+```bibtex
+@misc{zhong2023agieval,
+ title={AGIEval: A Human-Centric Benchmark for Evaluating Foundation Models},
+ author={Wanjun Zhong and Ruixiang Cui and Yiduo Guo and Yaobo Liang and Shuai Lu and Yanlin Wang and Amin Saied and Weizhu Chen and Nan Duan},
+ year={2023},
+ eprint={2304.06364},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@article{huang2023ceval,
+    title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
+    author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
+    journal={arXiv preprint arXiv:2305.08322},
+    year={2023}
+}
+
+@misc{li2023cmmlu,
+ title={CMMLU: Measuring massive multitask language understanding in Chinese},
+ author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
+ year={2023},
+ eprint={2306.09212},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@inproceedings{Zhang2023EvaluatingTP,
+ title={Evaluating the Performance of Large Language Models on GAOKAO Benchmark},
+ author={Xiaotian Zhang and Chunyang Li and Yi Zong and Zhengyu Ying and Liang He and Xipeng Qiu},
+ year={2023}
+}
+
+@misc{bai2023longbench,
+ title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding},
+ author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li},
+ year={2023},
+ eprint={2308.14508},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@article{hendryckstest2021,
+ title={Measuring Massive Multitask Language Understanding},
+ author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
+ journal={Proceedings of the International Conference on Learning Representations (ICLR)},
+ year={2021}
+}
+
+@article{hendrycks2021ethics,
+ title={Aligning AI With Shared Human Values},
+ author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt},
+ journal={Proceedings of the International Conference on Learning Representations (ICLR)},
+ year={2021}
+}
+
+@misc{zheng2023judging,
+ title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
+ author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
+ year={2023},
+ eprint={2306.05685},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+```
diff --git a/applications/ColossalEval/colossal_eval/__init__.py b/applications/ColossalEval/colossal_eval/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py
new file mode 100644
index 000000000000..4ea173198f5a
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py
@@ -0,0 +1,19 @@
+from .agieval import AGIEvalDataset
+from .base import BaseDataset
+from .ceval import CEvalDataset
+from .cmmlu import CMMLUDataset
+from .colossalai import ColossalDataset
+from .gaokaobench import GaoKaoBenchDataset
+from .longbench import LongBenchDataset
+from .mmlu import MMLUDataset
+
+__all__ = [
+ "AGIEvalDataset",
+ "BaseDataset",
+ "CEvalDataset",
+ "CMMLUDataset",
+ "GaoKaoBenchDataset",
+ "LongBenchDataset",
+ "MMLUDataset",
+ "ColossalDataset",
+]
diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py
new file mode 100644
index 000000000000..92ebd65931ed
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py
@@ -0,0 +1,247 @@
+# Adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/dataset_loader.py.
+
+import ast
+import glob
+import os
+from copy import deepcopy
+from typing import Dict, List
+
+import pandas as pd
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+# define the datasets
+english_qa_datasets = [
+ "lsat-ar",
+ "lsat-lr",
+ "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+]
+chinese_qa_datasets = [
+ "logiqa-zh",
+ "jec-qa-kd",
+ "jec-qa-ca",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ "gaokao-physics",
+ "gaokao-mathqa",
+]
+english_cloze_datasets = ["math"]
+chinese_cloze_datasets = ["gaokao-mathcloze"]
+
+multi_choice_datasets = ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"]
+math_output_datasets = {"gaokao-mathcloze", "math"}
+
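+# Shared inference settings for all AGIEval subsets. As usage in this repository
+# suggests: `calculate_loss` toggles loss computation over the target, `all_classes`
+# lists the choice labels for single-choice questions (None for free-form answers),
+# `language` selects the prompt language, `pretrain` appears to switch to
+# pretraining-style (perplexity) evaluation, and `max_new_tokens` caps generation length.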
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict:
+ """Modified from https://github.com/microsoft/AGIEval/blob/main/src/dataset_loader.py#L190"""
+ try:
+ all_classes = None
+ passage = line["passage"] if line["passage"] is not None else ""
+
+ if dataset_name in english_qa_datasets:
+ option_string = "ABCDEFG"
+ count = len(line["options"])
+
+ input = (
+ "Question: "
+ + line["question"]
+ + " "
+ + "Choose from the following options: "
+ + " ".join(line["options"])
+ + "\n"
+ + "Answer: "
+ )
+
+ all_classes = list(option_string[0:count])
+
+ elif dataset_name in chinese_qa_datasets:
+ option_string = "ABCDEFG"
+ count = len(line["options"])
+
+ input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+
+ all_classes = list(option_string[0:count])
+
+ elif dataset_name in english_cloze_datasets:
+ input = "Question: " + line["question"] + "\n" + "Answer: "
+
+ elif dataset_name in chinese_cloze_datasets:
+ input = "问题:" + line["question"] + "\n" + "答案:"
+
+ return {
+ "instruction": input if not passage else passage + "\n\n" + input,
+ "target": line["label"] if line["label"] else line["answer"],
+ }, all_classes
+
+    except NameError:
+        # An unknown dataset_name leaves `input` unbound, so the access above raises
+        # UnboundLocalError (a subclass of NameError), which is caught here.
+        logger.info("Dataset not defined.")
+
+
+# process few-shot raw_prompts
+def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False):
+ skip_passage = False
+ if dataset_name == "sat-en-without-passage":
+ skip_passage = True
+ dataset_name = "sat-en"
+    demonstrations = []
+ # read the prompts by context and explanation
+ context_row = [0, 1, 3, 5, 7, 9]
+ explanation_row = [0, 2, 4, 6, 8, 10]
+ raw_prompts_context = pd.read_csv(
+ prompt_path, header=0, skiprows=lambda x: x not in context_row, keep_default_na=False
+ )
+ raw_prompts_explanation = pd.read_csv(
+ prompt_path, header=0, skiprows=lambda x: x not in explanation_row, keep_default_na=False
+ ).replace(r"\n\n", "\n", regex=True)
+ contexts = []
+ for line in list(raw_prompts_context[dataset_name]):
+ if line:
+ contexts.append(ast.literal_eval(line))
+ explanations = [exp for exp in raw_prompts_explanation[dataset_name] if exp]
+
+ for idx, (con, exp) in enumerate(zip(contexts, explanations)):
+ passage = con["passage"] if con["passage"] is not None and not skip_passage else ""
+ question = con["question"]
+ options = con["options"] if con["options"] is not None else ""
+ label = con["label"] if con["label"] is not None else ""
+ answer = con["answer"] if "answer" in con and con["answer"] is not None else ""
+
+ if dataset_name in english_qa_datasets:
+ question_input = (
+ "Question: "
+ + passage
+ + " "
+ + question
+ + "\n"
+ + "Choose from the following options: "
+ + " ".join(options)
+ + "\n"
+ + "Answer: {}".format(label)
+ )
+ elif dataset_name in chinese_qa_datasets:
+ question_input = (
+ "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
+ )
+ elif dataset_name in english_cloze_datasets:
+ question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)
+ elif dataset_name in chinese_cloze_datasets:
+ question_input = "问题:" + question + "\n" + "答案:{}".format(answer)
+ else:
+            raise ValueError(f"During loading few-shot examples, found unknown dataset: {dataset_name}")
+
+        if chat_mode:
+            demonstrations.append((question_input,))
+        else:
+            demonstrations.append(question_input + "\n")
+
+    return demonstrations
+
+
+class AGIEvalDataset(BaseDataset):
+ """
+ Dataset wrapper for AGIEval dataset.
+ Data source: https://github.com/microsoft/AGIEval
+ This dataset class will convert the original dataset into the inference dataset.
+
+    A few dirty data samples needed to be manually corrected in the original dataset:
+ Issue link: https://github.com/microsoft/AGIEval/issues/16
+ 1. Invalid options in line 190 in gaokao-chemistry.jsonl.
+ 2. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en-without-passage.jsonl.
+ 3. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en.jsonl.
+ 4. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en-without-passage.jsonl.
+ 5. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en.jsonl.
+ 6. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en-without-passage.jsonl.
+ 7. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en.jsonl.
+ 8. Label is empty in line 212 in jec-qa-kd.jsonl. Content is also dirty.
+    9. Actually, gaokao-mathqa.jsonl is also a multi-choice dataset. See lines 149, 286 and 287.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+
+ files = glob.glob(os.path.join(path, "*.jsonl"))
+ files.sort()
+
+ if few_shot:
+ prompt_path = os.path.join(path, "few_shot_prompts.csv")
+
+ for file in files:
+ dataset_name = os.path.basename(file)[0 : -len(".jsonl")]
+
+ few_shot_data = []
+ if few_shot:
+ # process demo once if it is few-shot-CoT
+ few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False)
+
+ dataset["test"][dataset_name] = {"data": []}
+
+            # `glob.glob` already returned the path joined with `path`, so use it
+            # directly; joining again would duplicate the directory prefix.
+            file_dir = file
+
+ loaded_jsonl = get_json_list(file_dir)
+
+            # It's been tested that each data sample in one subcategory has the same inference arguments.
+ _, all_classes = get_prompt(loaded_jsonl[0], dataset_name, logger)
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if all_classes is not None and dataset_name not in multi_choice_datasets:
+ inference_kwargs["all_classes"] = all_classes
+
+ if dataset_name in english_qa_datasets:
+ inference_kwargs["language"] = "English"
+ if dataset_name in chinese_qa_datasets:
+ inference_kwargs["language"] = "Chinese"
+ inference_kwargs["few_shot_data"] = few_shot_data
+
+ dataset["test"][dataset_name]["inference_kwargs"] = inference_kwargs
+
+ for line in loaded_jsonl:
+ info, all_classes = get_prompt(line, dataset_name, logger)
+
+ # Convert multi-choice answers to a single string.
+ # We will convert it back when evaluating.
+ # We do this because if target is a list, it should be only used for multiple target answers.
+ if dataset_name in multi_choice_datasets:
+ if isinstance(info["target"], str) and len(info["target"]) > 1:
+ # "gaokao-mathqa" actually contain multi-choice questions.
+ # This if clause is specially used for it.
+ info["target"] = "".join(info["target"].split())
+ else:
+ info["target"] = "".join(info["target"])
+
+ if isinstance(info["target"], list) and len(info["target"]) == 1:
+ info["target"] = info["target"][0]
+
+ data_sample = {
+ "dataset": "agieval",
+ "split": "test",
+ "category": dataset_name,
+ "instruction": info["instruction"],
+ "input": "",
+ "output": "",
+ "target": info["target"],
+ }
+
+ dataset["test"][dataset_name]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py
new file mode 100644
index 000000000000..45b0151b849f
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/base.py
@@ -0,0 +1,24 @@
+from abc import abstractmethod
+
+from colossal_eval.utils import jdump
+
+
+class BaseDataset:
+    """
+    Base class for dataset wrappers.
+
+    Args:
+        path: The path to the original dataset.
+        logger: Logger for the dataset.
+        few_shot: Whether to load few-shot demonstrations for the dataset.
+    """
+
+    def __init__(self, path, logger, few_shot):
+        self.dataset = self.load(path, logger, few_shot)
+
+    def save(self, save_path):
+        """Save the converted dataset."""
+        jdump(self.dataset, save_path)
+
+    @staticmethod
+    @abstractmethod
+    def load(path, logger, few_shot):
+        """Load the original dataset and convert it into the inference dataset."""
new file mode 100644
index 000000000000..32ec52087bd3
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py
@@ -0,0 +1,132 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+ceval_subject_mapping = {
+ "computer_network": ["Computer Network", "计算机网络", "STEM"],
+ "operating_system": ["Operating System", "操作系统", "STEM"],
+ "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"],
+ "college_programming": ["College Programming", "大学编程", "STEM"],
+ "college_physics": ["College Physics", "大学物理", "STEM"],
+ "college_chemistry": ["College Chemistry", "大学化学", "STEM"],
+ "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"],
+ "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"],
+ "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"],
+ "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"],
+ "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"],
+ "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"],
+ "high_school_physics": ["High School Physics", "高中物理", "STEM"],
+ "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"],
+ "high_school_biology": ["High School Biology", "高中生物", "STEM"],
+ "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"],
+ "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"],
+ "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"],
+ "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"],
+ "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"],
+ "college_economics": ["College Economics", "大学经济学", "Social Science"],
+ "business_administration": ["Business Administration", "工商管理", "Social Science"],
+ "marxism": ["Marxism", "马克思主义基本原理", "Social Science"],
+ "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"],
+ "education_science": ["Education Science", "教育学", "Social Science"],
+ "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"],
+ "high_school_politics": ["High School Politics", "高中政治", "Social Science"],
+ "high_school_geography": ["High School Geography", "高中地理", "Social Science"],
+ "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"],
+ "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"],
+ "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"],
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"],
+ "logic": ["Logic", "逻辑学", "Humanities"],
+ "law": ["Law", "法学", "Humanities"],
+ "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"],
+ "art_studies": ["Art Studies", "艺术学", "Humanities"],
+ "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"],
+ "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"],
+ "high_school_chinese": ["High School Chinese", "高中语文", "Humanities"],
+ "high_school_history": ["High School History", "高中历史", "Humanities"],
+ "middle_school_history": ["Middle School History", "初中历史", "Humanities"],
+ "civil_servant": ["Civil Servant", "公务员", "Other"],
+ "sports_science": ["Sports Science", "体育学", "Other"],
+ "plant_protection": ["Plant Protection", "植物保护", "Other"],
+ "basic_medicine": ["Basic Medicine", "基础医学", "Other"],
+ "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"],
+ "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
+ "accountant": ["Accountant", "注册会计师", "Other"],
+ "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+ "tax_accountant": ["Tax Accountant", "税务师", "Other"],
+ "physician": ["Physician", "医师资格", "Other"],
+}
+
+default_inference_kwargs = {
+ "calculate_loss": False,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
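+# Few-shot demonstrations are built by appending the gold answer to each dev-split prompt.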
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class CEvalDataset(BaseDataset):
+ """
+ Dataset class for CEval dataset.
+ Data source: https://huggingface.co/datasets/ceval/ceval-exam
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
+ subject = file[0 : -len(f"_{split}.csv")]
+ subject = ceval_subject_mapping[subject][1]
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": []}
+
+                # It's been tested that each data sample in one subcategory has the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ _ = next(reader)
+ for row in reader:
+                        # The dev split has answer and explanation columns, so len(row) is 8,
+                        # but the test split doesn't, so len(row) is 6.
+ assert len(row) >= 6
+ choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
+ data_sample = {
+ "dataset": "ceval",
+ "split": split,
+ "category": subject,
+ "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。",
+ "input": f"题目:{row[1]}\n{choices}\n答案:",
+ "output": "",
+ "target": row[6] if split == "dev" else "",
+ "id": int(row[0]),
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
new file mode 100644
index 000000000000..51f8ca14e0c8
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
@@ -0,0 +1,144 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+cmmlu_subject_mapping = {
+ "agronomy": "农学",
+ "anatomy": "解剖学",
+ "ancient_chinese": "古汉语",
+ "arts": "艺术学",
+ "astronomy": "天文学",
+ "business_ethics": "商业伦理",
+ "chinese_civil_service_exam": "中国公务员考试",
+ "chinese_driving_rule": "中国驾驶规则",
+ "chinese_food_culture": "中国饮食文化",
+ "chinese_foreign_policy": "中国外交政策",
+ "chinese_history": "中国历史",
+ "chinese_literature": "中国文学",
+ "chinese_teacher_qualification": "中国教师资格",
+ "clinical_knowledge": "临床知识",
+ "college_actuarial_science": "大学精算学",
+ "college_education": "大学教育学",
+ "college_engineering_hydrology": "大学工程水文学",
+ "college_law": "大学法律",
+ "college_mathematics": "大学数学",
+ "college_medical_statistics": "大学医学统计",
+ "college_medicine": "大学医学",
+ "computer_science": "计算机科学",
+ "computer_security": "计算机安全",
+ "conceptual_physics": "概念物理学",
+ "construction_project_management": "建设工程管理",
+ "economics": "经济学",
+ "education": "教育学",
+ "electrical_engineering": "电气工程",
+ "elementary_chinese": "小学语文",
+ "elementary_commonsense": "小学常识",
+ "elementary_information_and_technology": "小学信息技术",
+ "elementary_mathematics": "初等数学",
+ "ethnology": "民族学",
+ "food_science": "食品科学",
+ "genetics": "遗传学",
+ "global_facts": "全球事实",
+ "high_school_biology": "高中生物",
+ "high_school_chemistry": "高中化学",
+ "high_school_geography": "高中地理",
+ "high_school_mathematics": "高中数学",
+ "high_school_physics": "高中物理学",
+ "high_school_politics": "高中政治",
+ "human_sexuality": "人类性行为",
+ "international_law": "国际法学",
+ "journalism": "新闻学",
+ "jurisprudence": "法理学",
+ "legal_and_moral_basis": "法律与道德基础",
+ "logical": "逻辑学",
+ "machine_learning": "机器学习",
+ "management": "管理学",
+ "marketing": "市场营销",
+ "marxist_theory": "马克思主义理论",
+ "modern_chinese": "现代汉语",
+ "nutrition": "营养学",
+ "philosophy": "哲学",
+ "professional_accounting": "专业会计",
+ "professional_law": "专业法学",
+ "professional_medicine": "专业医学",
+ "professional_psychology": "专业心理学",
+ "public_relations": "公共关系",
+ "security_study": "安全研究",
+ "sociology": "社会学",
+ "sports_science": "体育学",
+ "traditional_chinese_medicine": "中医中药",
+ "virology": "病毒学",
+ "world_history": "世界历史",
+ "world_religions": "世界宗教",
+}
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class CMMLUDataset(BaseDataset):
+ """
+ Dataset class for CMMLU dataset.
+ Data source: https://github.com/haonan-li/CMMLU/tree/master/data
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
+ subject = file[0 : -len(".csv")]
+ subject = cmmlu_subject_mapping[subject]
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": []}
+
+                # It's been tested that each data sample in one subcategory has the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ _ = next(reader)
+ for row in reader:
+ assert len(row) == 7
+ choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
+ data_sample = {
+ "dataset": "cmmlu",
+ "split": split,
+ "category": subject,
+ "instruction": f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。",
+ "input": f"题目:{row[1]}\n{choices}\n答案:",
+ "output": "",
+ "target": row[6],
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
new file mode 100644
index 000000000000..54ea478ae5d6
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
@@ -0,0 +1,70 @@
+from collections import defaultdict
+from copy import deepcopy
+from typing import Dict, List
+
+from colossal_eval.utils import jload
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+ "calculate_loss": False,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 256,
+}
+
+# You can add your own subcategories of questions here and specify whether a subcategory
+# consists of single-choice questions or has target answers that require loss calculation,
+# e.g. single_choice_question = {"logic"} (a hypothetical category name).
+single_choice_question = set()
+calculate_loss = set()
+
+
+def get_data_per_category(data):
+ data_per_category = defaultdict(list)
+ for item in data:
+ category = item["category"]
+ data_per_category[category].append(item)
+
+ return data_per_category
+
+
+class ColossalDataset(BaseDataset):
+ """
+ Dataset class for Colossal dataset.
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+ data = jload(path)
+ data_per_category = get_data_per_category(data)
+ categories = list(data_per_category.keys())
+
+ for category in categories:
+ dataset["test"][category] = {"data": []}
+ category_data = data_per_category[category]
+
+ dataset["test"][category]["inference_kwargs"] = deepcopy(default_inference_kwargs)
+
+ if category in calculate_loss:
+ dataset["test"][category]["inference_kwargs"]["calculate_loss"] = True
+ if category in single_choice_question:
+ dataset["test"][category]["inference_kwargs"]["all_classes"] = ["A", "B", "C", "D"]
+
+ for item in category_data:
+ data_sample = {
+ "dataset": "colossal",
+ "split": "test",
+ "category": category,
+ "instruction": item["instruction"],
+ "input": item["input"],
+ "output": "",
+ "target": item["target"],
+ "id": item["id"],
+ }
+ dataset["test"][category]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
new file mode 100644
index 000000000000..7bf0639e4882
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
@@ -0,0 +1,122 @@
+import json
+import os
+import re
+from copy import deepcopy
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+multi_choice_datasets = [
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "English Fill in Blanks",
+ "English Reading Comp",
+ "Geography MCQs",
+ "Physics MCQs",
+ "English Cloze Test",
+]
+
+chinese_qa_datasets = [
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "Geography MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Physics MCQs",
+ "Political Science MCQs",
+]
+english_qa_datasets = ["English MCQs", "English Fill in Blanks", "English Reading Comp", "English Cloze Test"]
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_all_classes(instruction: str):
+    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+    # Match option labels such as "A. " or "A."; note that the unescaped dot in the
+    # middle alternative also matches a capital letter followed by any character.
+    pattern = r"([A-Z]\. |[A-Z].|[A-Z]\.)"
+    options = sorted(list(set(re.findall(pattern, instruction))))
+    options = sorted(list(set([string[0] for string in options])))
+
+    # Keep only the leading contiguous run of labels (A, B, C, ...); once a gap
+    # appears, later matches are assumed to be spurious.
+    for i in range(len(options)):
+        if options[i] == letters[i]:
+            continue
+        else:
+            return options[0:i]
+    return options
+
+
+class GaoKaoBenchDataset(BaseDataset):
+ """
+ Dataset class for GAOKAO-Bench dataset.
+ Data source: https://github.com/OpenLMLab/GAOKAO-Bench/tree/main/data
+ This dataset class will convert the original dataset into the inference dataset.
+
+    A few typos in the original dataset needed to be corrected manually; some of the following have been fixed.
+ Issue link: https://github.com/OpenLMLab/GAOKAO-Bench/issues/20
+ 1. Option C missing in index 111 in 2010-2022_Chemistry_MCQs.json
+ 2. Option B missing "." after it in index 16 in 2012-2022_English_Cloze_Test.json
+ 3. Option G missing "." after it in index 23 in 2012-2022_English_Cloze_Test.json
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+ for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
+ files = os.listdir(os.path.join(path, "data", category))
+ files.sort()
+
+ for file in files:
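+                # File names look like "2010-2022_Chemistry_MCQs.json"; drop the leading
+                # year-range prefix (10 characters) and the ".json" suffix to get the subject.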
+ subject = file[10:-5].split("_")
+ subject = " ".join(subject)
+ dataset["test"][subject] = {"data": []}
+
+ file_dir = os.path.join(path, "data", category, file)
+
+ with open(file_dir, encoding="utf-8") as f:
+ data = json.load(f)
+
+                # It's been tested that each data sample in one subcategory has the same inference arguments.
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if category == "Multiple-choice_Questions" and subject not in multi_choice_datasets:
+ all_classes = get_all_classes(data["example"][0]["question"])
+ inference_kwargs["all_classes"] = all_classes
+ if subject in english_qa_datasets:
+ inference_kwargs["language"] = "English"
+ if subject in chinese_qa_datasets:
+ inference_kwargs["language"] = "Chinese"
+
+ dataset["test"][subject]["inference_kwargs"] = inference_kwargs
+
+ for sample in data["example"]:
+ # Convert multi-choice answers to a single string.
+ # We will convert it back when evaluating.
+ # We do this because if target is a list, it should be only used for multiple target answers.
+ if subject in multi_choice_datasets:
+ sample["answer"] = "".join(sample["answer"])
+
+ if isinstance(sample["answer"], list) and len(sample["answer"]) == 1:
+ sample["answer"] = sample["answer"][0]
+
+ data_sample = {
+ "dataset": "gaokaobench",
+ "split": "test",
+ "category": f"{category[:-10]}-{subject}",
+ "instruction": sample["question"].strip() + "\n答案:",
+ "input": "",
+ "output": "",
+ "target": sample["answer"],
+ }
+
+ dataset["test"][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py
new file mode 100644
index 000000000000..9ea5e3c7d77f
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py
@@ -0,0 +1,120 @@
+import os
+from copy import deepcopy
+from typing import Dict, List
+
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
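+# The prompt templates and per-dataset generation caps below appear to mirror the
+# official LongBench configuration (dataset2prompt.json and dataset2maxlen.json).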
+dataset2prompt = {
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
+ "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:',
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
+}
+
+dataset2maxlen = {
+ "narrativeqa": 128,
+ "qasper": 128,
+ "multifieldqa_en": 64,
+ "multifieldqa_zh": 64,
+ "hotpotqa": 32,
+ "2wikimqa": 32,
+ "musique": 32,
+ "dureader": 128,
+ "gov_report": 512,
+ "qmsum": 512,
+ "multi_news": 512,
+ "vcsum": 512,
+ "trec": 64,
+ "triviaqa": 32,
+ "samsum": 128,
+ "lsht": 64,
+ "passage_count": 32,
+ "passage_retrieval_en": 32,
+ "passage_retrieval_zh": 32,
+ "lcc": 64,
+ "repobench-p": 64,
+}
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+class LongBenchDataset(BaseDataset):
+ """
+ Dataset class for LongBench dataset.
+ Data source: https://huggingface.co/datasets/THUDM/LongBench
+ This dataset class will convert the original dataset into the inference dataset.
+
+ Issue link: https://github.com/THUDM/LongBench/issues/15 (fixed)
+ There are duplicate target answers in `nq.jsonl`, but this doesn't affect evaluation results.
+    Also doesn't affect perplexity calculation (the program only needs to select the minimum loss).
+ """
+
+ @staticmethod
+    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+        # `few_shot` is accepted for interface compatibility with BaseDataset.__init__;
+        # LongBench itself defines no few-shot split.
+ dataset = {"test": {}}
+
+ files = os.listdir(path)
+ files.sort()
+
+ for file in files:
+ category = file[0:-6]
+
+ if category.endswith("_e"):
+ continue
+
+ dataset["test"][category] = {"data": []}
+
+ file_dir = os.path.join(path, file)
+
+ loaded_jsonl = get_json_list(file_dir)
+
+            # It's been tested that each data sample in one subcategory has the same inference arguments.
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if loaded_jsonl[0]["all_classes"] is not None:
+ inference_kwargs["all_classes"] = loaded_jsonl[0]["all_classes"]
+ inference_kwargs["max_new_tokens"] = dataset2maxlen[category]
+ dataset["test"][category]["inference_kwargs"] = inference_kwargs
+
+ for sample in loaded_jsonl:
+ prompt = dataset2prompt[category].format(**sample)
+
+ data_sample = {
+ "dataset": "longbench",
+ "split": "test",
+ "category": category,
+ "instruction": prompt,
+ "input": "",
+ "output": "",
+ "target": sample["answers"],
+ }
+
+ dataset["test"][category]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
new file mode 100644
index 000000000000..b89c0a13cff1
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
@@ -0,0 +1,73 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "English",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class MMLUDataset(BaseDataset):
+ """
+ Dataset class for MMLU dataset.
+ Data source: https://github.com/hendrycks/test
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
+ subject = file[0 : -len(f"_{split}.csv")].split("_")
+ subject = " ".join([word.title() if word != "us" else "US" for word in subject])
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": [], "inference_kwargs": {}}
+
+                # It's been tested that each data sample in one subcategory has the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ for row in reader:
+ assert len(row) == 6
+ choices = f"A. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}"
+ data_sample = {
+ "dataset": "mmlu",
+ "split": split,
+ "category": subject,
+ "instruction": f"The following is a single-choice question on {subject}. Answer the question by replying A, B, C or D.",
+ "input": f"Question: {row[0]}\n{choices}\nAnswer: ",
+ "output": "",
+ "target": row[5],
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md
new file mode 100644
index 000000000000..37fbda4c8647
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md
@@ -0,0 +1,248 @@
+# GPT Evaluation
+## Table of Contents
+- [Overview](#overview)
+- [GPT Evaluation](#gpt-evaluation)
+ - [Evaluation Category](#evaluation-category)
+ - [Evaluation Category Examples](#evaluation-category-examples)
+ - [Evaluation Metrics](#evaluation-metrics)
+- [Evaluation Process](#evaluation-process)
+ - [Data Format](#data-format)
+ - [Prompt](#prompt)
+ - [Battle Prompt](#battle-prompt)
+ - [Evaluation Prompt](#evaluation-prompt)
+ - [Evaluation](#evaluation)
+ - [Configuration](#configuration)
+ - [Evaluate](#evaluate)
+- [FAQ](#faq)
+- [Citations](#citations)
+
+
+## Overview
+
+In this directory, we introduce how you can evaluate your model using GPT models. Evaluation of both Chinese and English capability is supported, and we provide the following functions:
+
+* Compare the performance of two different models (battle).
+* Rate the model according to pre-defined metrics using prompting design.
+* Rate the model according to pre-defined metrics with additional reference answer using prompting design.
+
+## GPT Evaluation
+
+### Evaluation Category
+
+Our evaluation pipeline can examine the model's capability using different categories of questions. The following table includes some example categories. You can add your own questions.
+
+| Evaluation Category | Description |
+| :-----------------: | :----------------------------------------------------------- |
+| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. |
+| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. |
+| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating high-quality, human-like text is required. |
+| Open QA | Models are asked to answer an open QA question (without context provided). The capability of answering questions with the models' own knowledge base is required. |
+| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. |
+
+
+### Evaluation Category Examples
+To better understand each evaluation category, here are some example questions. They can be found in the `configs/gpt_evaluation/data` folder.
+
+
+| Evaluation Category | Chinese Example | English Example |
+| :-----------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
+| Brainstorming | 列举一些可以促进头发生长的食物。 | How do you properly chop an onion without crying? |
+| Chat | 基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。
小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。
老李:你好,小张,我很乐意帮助你。你想问些什么?
小张:我想知道如何确定鸡的品种和性别?
老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?
小张:
| Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.
Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice?
Emma: Hi Alex, sure. What kind of writing are you doing?
Alex: I'm trying to write a novel, but I just can't seem to find any inspiration.
Emma:
|
+| Generation | 请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。 | Write a set of guidelines for first-time pet owners on how to properly care for a new puppy. |
+| Open QA | 解释什么是RNA病毒和DNA病毒。 | Explain the process of osmosis in biological systems. |
+| Roleplay | 我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}” | I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is "I need a rap song about finding strength within yourself." |
+
+### Evaluation Metrics
+
+GPT evaluation uses GPT models to evaluate the predictions of different models, and different pre-defined evaluation metrics are applied to different categories. The following table shows the 10 pre-defined evaluation metrics both in Chinese and English:
+
+| Evaluation Metric | Prompt Words | CoT(Chain-of-Thought) |
+| :-------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
+| 语言组织
(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。
2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。
3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。
4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。
5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。
6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.
2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.
3. Determine if the answer is relevant to the question or topic and conveys a clear message.
4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.
5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.
6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. |
+| 切题
(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。
2. 阅读答案,确认答案是否直接回答了题目所问的问题。
3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。
4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。1. Read the question to determine what the question asks and what aspects of the question need to be answered.
2. Read the answers to make sure that they directly answer the question asked.
3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.
4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. |
+| 创意性
(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。
3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。
4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.
3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.
4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. |
+| 实用性
(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。
3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。
4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.
3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.
4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. |
+| 正确性
(Correctness) | 正确性(1-5):答案是否正确。Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。
2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。
1. Read the question carefully and try to answer the question yourself.
2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. |
+| 自然
(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。
2. 检查答案内容是否符合题目给定的身份。
3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。1. Read the question and determine the identity information provided in the question.
2. Check whether the content of the answer matches the identity given in the question.
3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. |
+| 参与感
(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。
2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。
3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。1. Read the questions to determine the context and background of the dialogue.
2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.
3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. |
+| 合理性
(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。
2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。
3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.
2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.
3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. |
+| 多样性
(Diversity) | 多样性(1-5):答案使用语言是否优美,具有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。
2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。
3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。
4. 检查回答的合理性和适度,看看回答是否夸张或离题。
5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.
2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.
3. Check the creativity and imagination of the response to see if the response is engaging to read on.
4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.
5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. |
+| 保真度
(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。
2. 阅读题目的请求,确认回答请求时需要注意的细节。
3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。
4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.
2. Read the question's request and confirm the details that need to be taken into account when answering the request.
3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.
4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. |
+
+GPT models evaluate the quality of the model predictions based on the given prompt words and give a score between 1 and 5.
+
+> **NOTE 1:** You can find all the prompt words and CoT (Chain-of-Thought) steps in `configs/gpt_evaluation/prompt/evaluation_prompt`.
+
+> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
+
+## Evaluation Process
+
+### Data Format
+
+A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question.
+An element should have the following fields:
+
+* `category` (str, compulsory): The category of the instruction / question.
+* `instruction` (str, compulsory): The instruction / question for the LLM.
+* `input` (str, optional): The additional context of the instruction / question.
+* `output` (str, optional): The model output of the instruction, models will fill in this field during inference time.
+* `target` (str, optional): The target answer for the instruction.
+* `id` (int, compulsory): The ID of the instruction / question.
+
+Example:
+
+```json
+[
+ {
+ "category": "brainstorming",
+ "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。",
+ "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:",
+ "output": "",
+ "target": "",
+ "id": 2
+ }
+]
+```
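+
+As a quick sanity check, a minimal sketch like the following (a hypothetical helper, not part of the ColossalEval pipeline) can load such a file and verify the compulsory fields:
+
+```python
+import json
+
+COMPULSORY_FIELDS = ("category", "instruction", "id")
+
+
+def check_records(path: str) -> None:
+    """Assert that an answer/target file follows the expected data format."""
+    with open(path, encoding="utf-8") as f:
+        records = json.load(f)
+    assert isinstance(records, list), "The top-level JSON object must be a list."
+    for record in records:
+        for field in COMPULSORY_FIELDS:
+            assert field in record, f"Record {record.get('id')} is missing '{field}'."
+
+
+check_records("answers.json")  # hypothetical file name
+```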
+
+### Prompt
+
+#### Battle Prompt
+
+The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `configs/gpt_evaluation/prompt/battle_prompt`.
+
+```json
+{
+ "id": 1,
+ "system_prompt": "你是一个检查回答质量的好助手。",
+ "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n",
+ "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。"
+}
+```
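+
+For illustration, the battle prompt is built by substituting the question and the two model answers into `prompt_template`; the following is a minimal sketch (assuming the JSON above has been loaded into `battle_prompt`; the function name is illustrative, not the pipeline's actual code):
+
+```python
+def build_battle_prompt(battle_prompt: dict, question: str, answer_1: str, answer_2: str) -> str:
+    # Substitute the question, both model answers and the grading request
+    # into the template shown above.
+    return battle_prompt["prompt_template"].format(
+        question=question,
+        answer_1=answer_1,
+        answer_2=answer_2,
+        prompt=battle_prompt["prompt"],
+    )
+```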
+
+#### Evaluation Prompt
+
+The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide the CoT (Chain-of-Thought) steps in `CoT`. You can find example evaluation prompt files for Chinese and English in `configs/gpt_evaluation/prompt/evaluation_prompt`.
+
+```json
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:"
+ },
+ "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ }
+}
+```
+
+`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file.
+
+`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`.
+
+### Evaluation
+
+#### Configuration
+
+The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics in key `GPT`. You can find an example English config file in `configs/gpt_evaluation/config/config_en.json`.
+
+```json
+{
+ "language": "cn",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ }
+ }
+}
+```
+
+`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.
+
+`"category"`: the category/categories needed to evaluate the model capability.
+
+`"GPT"`: the metrics you want to use for GPT evaluation.
+
+
+#### Evaluate
+
+After setting up the configuration file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. To compare the answers of two different models, specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. To evaluate a single answer file, both `answer_file_list` and `model_name_list` should have length 1, and the program will perform the evaluation using automatic metrics and GPT models.
+
+An example script is provided as follows:
+
+```shell
+python eval.py \
+ --config_file "path to the config file" \
+ --battle_prompt_file "path to the prompt file for battle" \
+ --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+ --target_file "path to the target answer file" \
+ --answer_file_list "path to the answer files of at most 2 models" \
+ --model_name_list "the names of at most 2 models" \
+ --gpt_model "which GPT model to use for evaluation" \
+ --save_path "path to save results" \
+ --openai_key "your openai key" \
+```
+
+If you want GPT evaluation with reference, you can add the argument `--gpt_with_reference`, but make sure the reference file has target answers.
+
+## FAQ
+
+How can I add a new GPT evaluation metric?
+
+For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT (Chain-of-Thought) steps in the evaluation prompt file in `configs/gpt_evaluation/prompt/evaluation_prompt`. The CoT can be generated using ChatGPT: simply prompt ChatGPT to generate evaluation steps for the new metric.
+
+```json
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness"
+ },
+ "CoT": {
+ "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ }
+}
+```
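+
+After updating the prompt file, also add `persuasiveness` to the `GPT` list of the `brainstorming` category in your config file so that the new metric is actually evaluated.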
+
+
+
+## Citations
+
+```bibtex
+@misc{vicuna2023,
+ title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality},
+ url = {https://vicuna.lmsys.org},
+ author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.},
+ month = {March},
+ year = {2023}
+}
+
+@misc{liu2023geval,
+ title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment},
+ author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu},
+ year={2023},
+ eprint={2303.16634},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
diff --git a/applications/ColossalEval/colossal_eval/evaluate/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py
new file mode 100644
index 000000000000..3c5df09a6909
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py
@@ -0,0 +1,3 @@
+from .dataset_evaluator import DatasetEvaluator
+
+__all__ = ["DatasetEvaluator"]
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
new file mode 100644
index 000000000000..c70988707a15
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
@@ -0,0 +1,269 @@
+from typing import Dict, List
+
+import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
+import numpy as np
+import tqdm
+
+LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
+LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
+CombinedMetrics = ["combined_single_choice_accuracy"]
+OtherMetrics = [
+ "f1_score",
+ "f1_zh_score",
+ "rouge_score",
+ "rouge_zh_score",
+ "retrieval_score",
+ "retrieval_zh_score",
+ "classification_score",
+ "code_sim_score",
+ "count_score",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "single_choice_accuracy",
+]
+
+
+class DatasetEvaluator(object):
+ """
+ Dataset evaluator.
+
+ """
+
+ def __init__(self):
+ pass
+
+ def _calculate_label_metrics(self, metric: str, category: str):
+ """Calculate label-based metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ str_label_map = {
+ choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
+ }
+
+ references = [str_label_map[sample["target"]] for sample in self.data[category]["data"]]
+ [sample["output"] for sample in self.data[category]["data"]]
+
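+        # If the softmax over choices contains NaN, fall back to exact matching on the generated text for that sample.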
+ flag = False
+ softmaxs = []
+ for i, sample in enumerate(self.data[category]["data"]):
+ if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+ if not flag:
+ print(
+ f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+ )
+ flag = True
+ score = 0
+ for ref in sample["target"]:
+ score = max(
+ score,
+ metric_helper.single_choice_accuracy(
+ sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
+ ),
+ )
+ softmaxs.append(references[i] if score == 1 else -1)
+ else:
+ softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+
+ references = np.array(references)
+ softmaxs = np.array(softmaxs)
+ scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100
+
+ self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += scores * weight
+
+ def _calculate_combined_metrics(self, metric: str, category: str):
+ """Calculate combined metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ references = [sample["target"] for sample in self.data[category]["data"]]
+ predictions = [sample["output"] for sample in self.data[category]["data"]]
+
+ str_label_map = {
+ choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
+ }
+
+ references_labels = [str_label_map[sample["target"][0]] for sample in self.data[category]["data"]]
+ predictions = [sample["output"] for sample in self.data[category]["data"]]
+
+ flag = False
+ softmaxs = []
+ for i, sample in enumerate(self.data[category]["data"]):
+ if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+ if not flag:
+ print(
+ f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+ )
+ flag = True
+ score = 0
+ for ref in sample["target"]:
+ score = max(
+ score,
+ metric_helper.single_choice_accuracy(
+ sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
+ ),
+ )
+ softmaxs.append(references[i] if score == 1 else -1)
+ else:
+ softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+
+        metric_method = getattr(metric_helper, metric)  # resolve the metric function by name
+
+ total_score = 0.0
+ for prediction, reference, references_label, softmax in zip(
+ predictions, references, references_labels, softmaxs
+ ):
+ score = 0.0
+
+ for ref in reference:
+ score = max(
+ score,
+ metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
+ )
+ if references_label == softmax:
+ score = 1
+
+ total_score += score
+ total_score = total_score * 100 / len(self.data[category]["data"])
+
+ self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += total_score * weight
+
+ def _calculate_other_metrics(self, metric: str, category: str):
+ """Calculate other metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ references = [sample["target"] for sample in self.data[category]["data"]]
+ predictions = [sample["output"] for sample in self.data[category]["data"]]
+
+        metric_method = getattr(metric_helper, metric)  # resolve the metric function by name
+
+ total_score = 0.0
+ for prediction, reference in zip(predictions, references):
+ score = 0.0
+ for ref in reference:
+ score = max(
+ score,
+ metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
+ )
+ total_score += score
+ total_score = total_score * 100 / len(predictions)
+
+ self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += total_score * weight
+
+ def _calculate_loss_metrics(self, metric: str, category: str):
+ """Calculate perplexity."""
+ if metric == "perplexity":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
+ perplexity = np.mean(np.exp(np.array(losses)))
+
+ self.evaluation_results["perplexity"][category] = (perplexity, len(self.data[category]["data"]))
+ self.evaluation_results["perplexity"]["ALL"] += perplexity * weight
+ elif metric == "ppl_score":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
+ perplexity_score = np.mean(np.exp(-np.array(losses))) * 100
+
+ self.evaluation_results["ppl_score"][category] = (perplexity_score, len(self.data[category]["data"]))
+ self.evaluation_results["ppl_score"]["ALL"] += perplexity_score * weight
+ elif metric == "ppl_score_over_choices" and self.data[category]["inference_kwargs"]["all_classes"] is not None:
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ loss_over_choices = [sample["loss_over_choices"] for sample in self.data[category]["data"]]
+ perplexity_score_over_choices = np.mean(np.exp(-np.array(loss_over_choices))) * 100
+
+ self.evaluation_results["ppl_score_over_choices"][category] = (
+ perplexity_score_over_choices,
+ len(self.data[category]["data"]),
+ )
+ self.evaluation_results["ppl_score_over_choices"]["ALL"] += perplexity_score_over_choices * weight
+ elif metric == "per_byte_perplexity":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
+ perplexity = np.mean(np.exp(np.array(losses) / np.array(self.N_bytes[category])))
+
+ self.evaluation_results["per_byte_perplexity"][category] = perplexity
+ self.evaluation_results["per_byte_perplexity"]["ALL"] += perplexity * weight
+ elif metric == "per_byte_ppl_score":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
+ perplexity_score = np.mean(np.exp(-np.array(losses) / np.array(self.N_bytes[category]))) * 100
+
+ self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score
+ self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight
+
+ def _evaluate(self):
+ """Calculate and return evaluation results"""
+
+ for metric in self.metrics:
+ pbar = tqdm.tqdm(
+ desc=f"{self.dataset_name}-{metric}-{self.model_name}", total=len(self.suggested_categories[metric])
+ )
+
+ if metric in LabelBasedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_label_metrics(metric, category)
+ pbar.update(1)
+ elif metric in LossBasedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_loss_metrics(metric, category)
+ pbar.update(1)
+ elif metric in CombinedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_combined_metrics(metric, category)
+ pbar.update(1)
+ elif metric in OtherMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_other_metrics(metric, category)
+ pbar.update(1)
+
+ return self.evaluation_results
+
+    def get_evaluation_results(self, data: Dict[str, Dict], dataset_name: str, model_name: str, metrics: List[str]):
+ """
+ Evaluate inference data on the given metrics.
+
+ Args:
+ data: Data to be evaluated.
+            dataset_name: Name of the dataset.
+            model_name: Name of the model.
+ metrics: Metrics used to evaluate.
+
+ """
+ self.data = data
+ self.dataset_name = dataset_name
+ self.model_name = model_name
+ self.categories = list(data.keys())
+ self.metrics = metrics
+
+ self.evaluation_results = {
+ metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
+ }
+
+ self.total_length = 0
+ self.total_single_choices = 0
+ for value in self.data.values():
+ self.total_length += len(value["data"])
+ if value["inference_kwargs"]["all_classes"] is not None:
+ self.total_single_choices += len(value["data"])
+
+ self.metric_total_length = {metric: 0 for metric in self.metrics}
+ self.suggested_categories = {metric: [] for metric in self.metrics}
+
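+        # Resolve the categories each metric applies to; "ALL" expands to every category present in the data.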
+ for metric in self.metrics:
+ self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric]
+ if "ALL" in self.suggested_categories[metric]:
+ self.suggested_categories[metric] = self.categories
+ self.metric_total_length[metric] = self.total_length
+ continue
+ for category in self.suggested_categories[metric]:
+ self.metric_total_length[metric] += len(self.data[category]["data"])
+
+ if "per_byte_perplexity" in self.metrics or "per_byte_ppl_score" in self.metrics:
+ self.N_bytes = {category: [] for category in self.categories}
+ for category in self.categories:
+ samples = self.data[category]["data"]
+ for sample in samples:
+ self.N_bytes[category].append(sample["byte_num"][0])
+
+ return self._evaluate()
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
new file mode 100644
index 000000000000..914465478dec
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
@@ -0,0 +1,623 @@
+# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
+# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
+# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
+
+import difflib
+import re
+import string
+from collections import Counter
+
+import jieba
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
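+# Maps dataset name -> metric name -> the subcategories that metric applies to.
+# "ALL" is expanded at runtime to every category present in the inference data.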
+metrics4subcategory = {
+ "pretrain": {
+ "perplexity": ["ALL"],
+ "ppl_score": ["ALL"],
+ "per_byte_perplexity": ["ALL"],
+ "per_byte_ppl_score": ["ALL"],
+ },
+ # The commented are non 4-choice questions.
+ "agieval": {
+ "combined_single_choice_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "first_token_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "single_choice_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"],
+ "math_equivalence": ["gaokao-mathcloze", "math"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": [
+ "lsat-ar",
+ "lsat-lr",
+ "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "jec-qa-kd",
+ "jec-qa-ca",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ "gaokao-physics",
+ "gaokao-mathqa",
+ ],
+ "ppl_score": ["ALL"],
+ },
+ "cmmlu": {
+ "first_token_accuracy": ["ALL"],
+ "single_choice_accuracy": ["ALL"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "gaokaobench": {
+ "combined_single_choice_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "first_token_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "single_choice_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "multi_choice_accuracy": [
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "English Fill in Blanks",
+ "English Reading Comp",
+ "Geography MCQs",
+ "Physics MCQs",
+ "English Cloze Test",
+ ],
+ "math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"],
+ "rouge_score": ["English Language Cloze Passage"],
+ "rouge_zh_score": [
+ "Chinese Language Famous Passages and Sentences Dictation",
+ "Chemistry Open-ended Questions",
+ "History Open-ended Questions",
+ "Biology Open-ended Questions",
+ "Political Science Open-ended Questions",
+ "English Language Error Correction",
+ "Chinese Language Language and Writing Skills Open-ended Questions",
+ "Math II Open-ended Questions",
+ "Chinese Language Literary Text Reading",
+ "Chinese Language Ancient Poetry Reading",
+ "Chinese Language Classical Chinese Reading",
+ "Physics Open-ended Questions",
+ "Math I Open-ended Questions",
+ "Geography Open-ended Questions",
+ "Chinese Language Practical Text Reading",
+ ],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "longbench": {
+ "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"],
+ "f1_zh_score": ["multifieldqa_zh"],
+ "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"],
+ "rouge_zh_score": ["dureader", "vcsum"],
+ "retrieval_score": ["passage_retrieval_en"],
+ "retrieval_zh_score": ["passage_retrieval_zh"],
+ "classification_score": ["trec", "lsht"],
+ "code_sim_score": ["lcc", "repobench-p"],
+ "count_score": ["passage_count"],
+ "perplexity": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "mmlu": {
+ "first_token_accuracy": ["ALL"],
+ "single_choice_accuracy": ["ALL"],
+ "accuracy": ["ALL"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+}
+
+
+def _fix_fracs(string):
+ substrs = string.split("\\frac")
+ new_str = substrs[0]
+ if len(substrs) > 1:
+ substrs = substrs[1:]
+ for substr in substrs:
+ new_str += "\\frac"
+ if substr[0] == "{":
+ new_str += substr
+ else:
+ try:
+ assert len(substr) >= 2
+ except:
+ return string
+ a = substr[0]
+ b = substr[1]
+ if b != "{":
+ if len(substr) > 2:
+ post_substr = substr[2:]
+ new_str += "{" + a + "}{" + b + "}" + post_substr
+ else:
+ new_str += "{" + a + "}{" + b + "}"
+ else:
+ if len(substr) > 2:
+ post_substr = substr[2:]
+ new_str += "{" + a + "}" + b + post_substr
+ else:
+ new_str += "{" + a + "}" + b
+ string = new_str
+ return string
+
+
+def _fix_a_slash_b(string):
+ if len(string.split("/")) != 2:
+ return string
+ a = string.split("/")[0]
+ b = string.split("/")[1]
+ try:
+ a = int(a)
+ b = int(b)
+ assert string == "{}/{}".format(a, b)
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
+ return new_string
+ except:
+ return string
+
+
+def _remove_right_units(string):
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
+ if "\\text{ " in string:
+ splits = string.split("\\text{ ")
+ assert len(splits) == 2
+ return splits[0]
+ else:
+ return string
+
+
+def _fix_sqrt(string):
+ if "\\sqrt" not in string:
+ return string
+ splits = string.split("\\sqrt")
+ new_string = splits[0]
+ for split in splits[1:]:
+ if split[0] != "{":
+ a = split[0]
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
+ else:
+ new_substr = "\\sqrt" + split
+ new_string += new_substr
+ return new_string
+
+
+def _strip_string(string):
+ # linebreaks
+ string = string.replace("\n", "")
+ # print(string)
+
+ # remove inverse spaces
+ string = string.replace("\\!", "")
+ # print(string)
+
+ # replace \\ with \
+ string = string.replace("\\\\", "\\")
+ # print(string)
+
+ # replace tfrac and dfrac with frac
+ string = string.replace("tfrac", "frac")
+ string = string.replace("dfrac", "frac")
+ # print(string)
+
+ # remove \left and \right
+ string = string.replace("\\left", "")
+ string = string.replace("\\right", "")
+ # print(string)
+
+ # Remove circ (degrees)
+ string = string.replace("^{\\circ}", "")
+ string = string.replace("^\\circ", "")
+
+ # remove dollar signs
+ string = string.replace("\\$", "")
+
+ # remove units (on the right)
+ string = _remove_right_units(string)
+
+ # remove percentage
+ string = string.replace("\\%", "")
+ string = string.replace("\%", "")
+
+ # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+ string = string.replace(" .", " 0.")
+ string = string.replace("{.", "{0.")
+ # if empty, return empty string
+ if len(string) == 0:
+ return string
+ if string[0] == ".":
+ string = "0" + string
+
+ # to consider: get rid of e.g. "k = " or "q = " at beginning
+ if len(string.split("=")) == 2:
+ if len(string.split("=")[0]) <= 2:
+ string = string.split("=")[1]
+
+ # fix sqrt3 --> sqrt{3}
+ string = _fix_sqrt(string)
+
+ # remove spaces
+ string = string.replace(" ", "")
+
+ # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
+ string = _fix_fracs(string)
+
+ # manually change 0.5 --> \frac{1}{2}
+ if string == "0.5":
+ string = "\\frac{1}{2}"
+
+ # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+ string = _fix_a_slash_b(string)
+
+ return string
+
+
+def parse_math_answer(raw_string):
+ def remove_boxed(s):
+ left = "\\boxed{"
+ try:
+ assert s[: len(left)] == left
+ assert s[-1] == "}"
+ answer = s[len(left) : -1]
+ if "=" in answer:
+ answer = answer.split("=")[-1].lstrip(" ")
+ return answer
+ except:
+ return None
+
+ def last_boxed_only_string(string):
+ idx = string.rfind("\\boxed")
+ if idx < 0:
+ idx = string.rfind("\\fbox")
+ if idx < 0:
+ return None
+ i = idx
+ right_brace_idx = None
+ num_left_braces_open = 0
+ while i < len(string):
+ if string[i] == "{":
+ num_left_braces_open += 1
+ if string[i] == "}":
+ num_left_braces_open -= 1
+ if num_left_braces_open == 0:
+ right_brace_idx = i
+ break
+ i += 1
+
+        if right_brace_idx is None:
+ retval = None
+ else:
+ retval = string[idx : right_brace_idx + 1]
+
+ return retval
+
+ def get_answer_with_dollar_sign(s):
+ first_pattern = "\$(.*)\$"
+ last_match = None
+ matches = re.findall(first_pattern, s)
+ if matches:
+ last_match = matches[-1]
+ if "=" in last_match:
+ last_match = last_match.split("=")[-1].lstrip(" ")
+ return last_match
+
+ def get_answer_without_dollar_sign(s):
+ last_match = None
+ if "=" in s:
+ last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
+ if "\\n" in last_match:
+ last_match = last_match.split("\\n")[0]
+ else:
+ pattern = "(?:\\$)?\d+(?:\.\d+)?(?![\w\d])"
+ matches = re.findall(pattern, s)
+ if matches:
+ last_match = matches[-1]
+ return last_match
+
+ if "\\boxed" in raw_string:
+ answer = remove_boxed(last_boxed_only_string(raw_string))
+ else:
+ answer = get_answer_with_dollar_sign(raw_string)
+ if not answer:
+ answer = get_answer_without_dollar_sign(raw_string)
+ return answer
+
+
+def math_equivalence(prediction, reference, **kwargs):
+ prediction = parse_math_answer(prediction)
+
+ if prediction is None and reference is None:
+ print("WARNING: Both None")
+ return False
+
+ if prediction is None or reference is None:
+ return False
+
+ try:
+ ss1 = _strip_string(prediction)
+ ss2 = _strip_string(reference)
+ return ss1 == ss2
+ except:
+ return prediction == reference
+
+
+def multi_choice_accuracy(prediction, reference, **kwargs):
+    # Only find uppercase letters not surrounded by lowercase letters
+    all_classes = kwargs.get("all_classes", None)
+    if all_classes:
+        pattern = f"(?<![a-z])[{all_classes[0]}-{all_classes[-1]}](?![a-z])"
+    else:
+        pattern = "(?<![a-z])[A-F](?![a-z])"
+
+    prediction = re.findall(pattern, prediction)
+    reference = re.findall(pattern, reference)
+
+    reference_set = set(reference)
+    score = 0.0
+    for p in set(prediction):
+        if p not in reference_set:
+            return 0.0
+        score += 1 / len(reference_set)
+    return score
+
+
+# (Additional metric helpers referenced elsewhere in this module, such as single_choice_accuracy,
+# first_token_accuracy, count_score, retrieval_score, retrieval_zh_score, code_sim_score,
+# normalize_answer and normalize_zh_answer, are omitted in this excerpt.)
+
+
+def classification_score(prediction, reference, **kwargs):
+    # Fuzzy-match the prediction against each class label; score 1 if the best match equals the reference.
+    all_classes = kwargs.get("all_classes", None)
+    best_match = None
+    highest_similarity = 0
+    for string in all_classes:
+        similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
+        if similarity > highest_similarity:
+            highest_similarity = similarity
+            best_match = string
+    score = float(best_match == reference)
+    return score
+
+
+def rouge_score(prediction, reference, **kwargs):
+ rouge = Rouge()
+ try:
+ scores = rouge.get_scores([prediction], [reference], avg=True)
+ except:
+ return 0.0
+ return scores["rouge-l"]["f"]
+
+
+def rouge_zh_score(prediction, reference, **kwargs):
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
+ reference = " ".join(list(jieba.cut(reference, cut_all=False)))
+ score = rouge_score(prediction, reference)
+ return score
+
+
+def _f1_score(prediction, reference, **kwargs):
+ common = Counter(prediction) & Counter(reference)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction)
+ recall = 1.0 * num_same / len(reference)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def f1_score(prediction, reference, **kwargs):
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(reference)
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ return _f1_score(prediction_tokens, ground_truth_tokens)
+
+
+def f1_zh_score(prediction, reference, **kwargs):
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
+ ground_truth_tokens = list(jieba.cut(reference, cut_all=False))
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
+ return _f1_score(prediction_tokens, ground_truth_tokens)
diff --git a/applications/ColossalEval/colossal_eval/evaluate/evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py
new file mode 100644
index 000000000000..11e204b504c5
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py
@@ -0,0 +1,110 @@
+import os
+from typing import Any, Dict, List
+
+import colossal_eval.evaluate.gpt_evaluate as gpt_evaluate
+
+from .utils import get_data_per_category
+
+
+class Evaluator(object):
+ """
+    Evaluator wraps GPT-3.5/GPT-4 based evaluation: single-model scoring and model-vs-model battles.
+
+ """
+
+ def __init__(
+ self,
+ params: Dict[str, Any],
+ battle_prompt: Dict[str, Any],
+ gpt_evaluation_prompt: Dict[str, Any],
+ gpt_model: str,
+ language: str,
+ gpt_with_reference: bool,
+ ) -> None:
+ self.params = params
+ self.battle_prompt = battle_prompt
+ self.gpt_evaluation_prompt = gpt_evaluation_prompt
+ self.gpt_model = gpt_model
+ self.language = language
+ self.gpt_with_reference = gpt_with_reference
+ self.gpt_evaluation_results = dict()
+ self.battle_results = []
+
+ def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None:
+ """
+ Comparison between two models using GPT-4 as the reviewer.
+ """
+
+ self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt)
+
+ def evaluate(self, answers: List[Dict], targets: List[Dict], save_path: str, model_name: str) -> None:
+ """
+ A comprehensive evaluation of the answers from the model.
+ The function evaluates the model's performance from different perspectives
+ using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
+
+ The metrics will be decided by the config file.
+
+ """
+
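+        # Group answers and targets by category so each category can use its own prompt and metric list.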
+ answers_per_category = get_data_per_category(answers, list(self.params.keys()))
+ targets_per_category = get_data_per_category(targets, list(self.params.keys()))
+
+ # gpt evaluation
+ for category in self.params:
+ if len(answers_per_category[category]) == 0:
+ print(f"Category {category} specified in your config doesn't have corresponding answers!")
+ continue
+
+ if self.params[category].get("GPT", None) is None:
+ continue
+
+ category_metrics = self.params[category]["GPT"]
+
+ prompt = self.gpt_evaluation_prompt.get(category, None)
+ if prompt is None:
+ print(f"No prompt for category {category}! Use prompt for category general now.")
+ prompt = self.gpt_evaluation_prompt["general"]
+
+ self.gpt_evaluation_results[category] = gpt_evaluate.evaluate(
+ answers_per_category[category],
+ prompt,
+ category_metrics,
+ category,
+ save_path,
+ model_name,
+ self.gpt_model,
+ self.language,
+ references=targets_per_category[category] if self.gpt_with_reference else None,
+ )
+
+ def save(self, path: str, model_name_list: List[str]) -> None:
+ """
+ Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
+
+ """
+
+ if len(model_name_list) == 2:
+ save_path = os.path.join(path, "gpt_evaluate", "battle_results")
+ gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
+ else:
+ if self.gpt_evaluation_results:
+ # Save evaluation results for GPT evaluation metrics.
+ gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
+ gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
+
+ all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
+ model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
+ )
+
+ # Start to calculate scores and save statistics.
+ gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
+ gpt_evaluate.save_gpt_evaluation_statistics(
+ model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
+ )
+
+ # Save charts and csv.
+ gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
+ gpt_evaluate.analyze_gpt_evaluation_statistics(
+ gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
+ )
diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
similarity index 80%
rename from applications/Chat/evaluate/gpt_evaluate.py
rename to applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
index f8cfb8d0f7e5..a0b1ed1143f0 100644
--- a/applications/Chat/evaluate/gpt_evaluate.py
+++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
@@ -11,23 +11,21 @@
import pandas as pd
import seaborn as sns
import tqdm
-from utils import jdump, jload
+from colossal_eval.utils import jdump, jload
ref_step_template = {
- "en":
- "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
- "cn":
- "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n"
+ "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
+ "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
}
ref_answer_template_general = {
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
- "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n"
+ "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
}
ref_answer_template_correctness = {
"en": "\nA correct answer is as follows:\n\n{answer}\n\n",
- "cn": "\n标准答案如下:\n\n{answer}\n\n"
+ "cn": "\n标准答案如下:\n\n{answer}\n\n",
}
@@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
response = openai.ChatCompletion.create(
model="gpt-4",
messages=[
- {
- "role": "system",
- "content": sys_prompt
- },
+ {"role": "system", "content": sys_prompt},
{
"role": "user",
"content": user_prompt,
@@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
return [float(sp[0]), float(sp[1])]
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
- except Exception as e:
+ except Exception:
return [-1, -1]
@@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert len(answer1) == len(answer2)
- handles = []
- evaluation_file = []
-
total_len = len(answer1)
question_idx_list = list(range(total_len))
@@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert answer1[i]["id"] == answer2[i]["id"]
answer_id = answer1[i]["id"]
- ques = answer1[i]["instruction"] if answer1[i][
- "input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"]
- cat = answer1[i]["category"]
+ ques = (
+ answer1[i]["instruction"]
+ if answer1[i]["input"] == ""
+ else answer1[i]["instruction"] + " " + answer1[i]["input"]
+ )
+ answer1[i]["category"]
ans1 = answer1[i]["output"]
ans2 = answer2[i]["output"]
@@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
step_to_add = ref_step_template[language]
- for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)"
+ for_the_given_answer = (
+ "{metric} (1-5) (directly give the score for the given answer):"
+ if language == "en"
+ else "{metric} (1-5) (直接对给定答案打分)"
+ )
# adjective is used to describe the word "answer" in the prompt.
adjective = "example" if language == "en" else "示例"
@@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
answer_to_add = ref_answer_template_correctness[language]
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
- step_to_add = step_to_add.format(metric=metric.lower(),
- adjective=adjective) + for_the_given_answer.format(metric=metric)
+ step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
+ metric=metric
+ )
return answer_to_add + step_to_add
@@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
for j in range(i):
messages_to_send.append(fill_in_message("user", user_messages[j]))
messages_to_send.append(
- fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]))
+ fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
+ )
# Length of user messages == Length of assistant messages + 1
# Because we always expect the api to response
@@ -351,17 +352,19 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
return assistant_responses[-1]
-def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
- inst: Dict[str, Any],
- metrics: List[str],
- language: str,
- reference: Dict[str, Any] = None,
- model: str = "gpt-3.5-turbo",
- max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_without_logprobs(
+ prompt: Dict[str, Any],
+ inst: Dict[str, Any],
+ metrics: List[str],
+ language: str,
+ reference: Dict[str, Any] = None,
+ model: str = "gpt-3.5-turbo",
+ max_tokens: int = 2048,
+) -> Dict[str, Any]:
"""
Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer.
    Temperature is set to 0 to make the model more deterministic.
Args:
prompt: a dictionary including prompt template, CoT and metrics.
@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
- question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@@ -398,12 +401,11 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
steps=prompt["CoT"][metric],
)
- if prompt_reference:
+ if prompt_reference and (reference["target"] or reference["output"]):
# Do a 2-round conversation
- response = multiturn_chat_completion([prompt_1st_round, prompt_reference],
- model,
- max_tokens=max_tokens,
- turns=2)
+ response = multiturn_chat_completion(
+ [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
+ )
else:
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
@@ -427,15 +429,14 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
return inst
-def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
- inst: Dict[str, Any],
- metrics: List[str],
- max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_with_logprobs(
+ prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
+) -> Dict[str, Any]:
"""
Use completion model(text-davinci-003) to evaluate one model answer.
Only completion models can return log probabilities.
    Temperature is set to 0 to make the model more deterministic.
Args:
prompt: a dictionary including prompt template, CoT and metrics.
@@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
- question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@@ -492,13 +493,17 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
return inst
-def evaluate(answers: List[Dict],
- prompt: Dict[str, Any],
- metrics: List[str],
- category: str,
- model: str,
- language: str,
- references: List[Dict] = None) -> List[Dict]:
+def evaluate(
+ answers: List[Dict],
+ prompt: Dict[str, Any],
+ metrics: List[str],
+ category: str,
+ save_path: str,
+ model_name: str,
+ model: str,
+ language: str,
+ references: List[Dict] = None,
+) -> List[Dict]:
"""
Use GPT models to evaluate model answers and save evaluation results.
@@ -522,6 +527,72 @@ def evaluate(answers: List[Dict],
metrics_str = ", ".join(x for x in metrics)
print(f"Category {category}'s metrics are {metrics_str}.")
+ gpt_base_save_path = os.path.join(save_path, "gpt_evaluate", "gpt_evaluate_results")
+ gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
+ category_file = os.path.join(gpt_evaluation_results_save_path, model_name, f"{category}_evaluation_results.json")
+
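+    # If results for this category already exist, re-evaluate only the samples whose evaluation came back empty.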
+ if os.path.exists(category_file):
+ print(f"Evaluation results for category {category}, model {model_name} already exists.")
+ print("Skip evaluating.")
+
+ evaluations = jload(category_file)
+
+ retry = []
+ evaluations_copy = deepcopy(evaluations)
+
+ success = []
+ for idx, e in enumerate(evaluations_copy):
+ keys = list(e["evaluation"].keys())
+ for key in keys:
+ if e["evaluation"][key] == {}:
+ retry.append(e["id"])
+ print(f"Re-evaluate id {e['id']} now.")
+ break
+ if e["id"] not in retry:
+ success.append(e)
+
+ if len(retry) == 0:
+ evaluations.sort(key=lambda x: x["id"])
+ print(f"{category} done.")
+ return evaluations
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = []
+ for idx, inst in enumerate(answers):
+ if not inst["id"] in retry:
+ continue
+ # Completion models can return log probabilities.
+ if model == "text-davinci-003":
+ future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
+ else:
+ future = executor.submit(
+ get_gpt_evaluation_without_logprobs,
+ prompt,
+ inst,
+ metrics,
+ language,
+ reference=None if references is None else references[idx],
+ model=model,
+ max_tokens=1,
+ )
+
+ futures.append(future)
+
+ for future in tqdm.tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"{category}: ",
+ total=len(futures),
+ ):
+ success.append(future.result())
+
+ success.sort(key=lambda x: x["id"])
+
+ print(f"Saving evaluation results for category {category}, model {model_name}.")
+
+ jdump(success, category_file)
+
+ return success
+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = []
for idx, inst in enumerate(answers):
@@ -529,21 +600,23 @@ def evaluate(answers: List[Dict],
if model == "text-davinci-003":
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
else:
- future = executor.submit(get_gpt_evaluation_without_logprobs,
- prompt,
- inst,
- metrics,
- language,
- reference=None if references is None else references[idx],
- model=model,
- max_tokens=1)
+ future = executor.submit(
+ get_gpt_evaluation_without_logprobs,
+ prompt,
+ inst,
+ metrics,
+ language,
+ reference=None if references is None else references[idx],
+ model=model,
+ max_tokens=1,
+ )
futures.append(future)
for future in tqdm.tqdm(
- concurrent.futures.as_completed(futures),
- desc=f"{category}: ",
- total=len(futures),
+ concurrent.futures.as_completed(futures),
+ desc=f"{category}: ",
+ total=len(futures),
):
evaluations.append(future.result())
@@ -551,6 +624,10 @@ def evaluate(answers: List[Dict],
print(f"{category} done.")
+ print(f"Saving evaluation results for category {category}, model {model_name}.")
+
+ jdump(evaluations, category_file)
+
return evaluations
@@ -593,7 +670,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float:
def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int:
"""
Calculate the score from the response returned by gpt-3.5-turbo or gpt-4.
    Different from text-davinci-003, this function directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4.
Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo.
Args:
@@ -610,12 +687,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) ->
return int(results[0])
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
- except Exception as e:
+ except Exception:
return 0
-def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any],
- save_path: str) -> Dict[str, Any]:
+def save_gpt_evaluation_results(
+ model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
+) -> Dict[str, Any]:
"""
Save evaluation results for different categories for one model.
@@ -667,10 +745,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append(
- calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]))
+ calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
+ )
else:
scores[metric].append(
- calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation))
+ calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
+ )
statistics = {}
for metric in metrics:
@@ -751,9 +831,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
for category in tqdm.tqdm(
- frame_per_category.keys(),
- desc=f"GPT evaluation: ",
- total=len(frame_per_category.keys()),
+ frame_per_category.keys(),
+ desc=f"GPT evaluation: ",
+ total=len(frame_per_category.keys()),
):
data = pd.DataFrame(frame_per_category[category])
diff --git a/applications/ColossalEval/colossal_eval/evaluate/utils.py b/applications/ColossalEval/colossal_eval/evaluate/utils.py
new file mode 100644
index 000000000000..69fec46705ab
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/utils.py
@@ -0,0 +1,8 @@
+def get_data_per_category(data, categories):
+ data_per_category = {category: [] for category in categories}
+ for item in data:
+ category = item["category"]
+ if category in categories:
+ data_per_category[category].append(item)
+
+ return data_per_category
diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py
new file mode 100644
index 000000000000..8f6c9b414145
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseModel
+from .chatglm import ChatGLM2Model, ChatGLMModel
+from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
diff --git a/applications/ColossalEval/colossal_eval/models/base.py b/applications/ColossalEval/colossal_eval/models/base.py
new file mode 100644
index 000000000000..aae796c1d56e
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/base.py
@@ -0,0 +1,78 @@
+from abc import abstractmethod
+from typing import Dict, List
+
+from colossal_eval.utils import Conversation, prompt_templates
+
+from colossalai.logging import DistributedLogger
+
+
+class BaseModel:
+ """
+ Base class for model wrapper.
+
+ Args:
+ path: The path to the model.
+ model_max_length: The maximum sequence length of the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+ """
+
+ def __init__(
+ self,
+ path: str,
+ model_max_length: int = 2048,
+ prompt_template: Conversation = None,
+ batch_size: int = 1,
+ logger: DistributedLogger = None,
+ ):
+ self.path = path
+ self.model_max_length = model_max_length
+
+ if prompt_template:
+ self.prompt_template = prompt_template
+ else:
+ self.prompt_template = prompt_templates["plain"]
+
+ self.batch_size = batch_size
+ self.logger = logger
+
+    @abstractmethod
+ def inference(self, data: List[Dict]) -> None:
+ """
+ Infer the given data.
+ This function will call self.generate() to get model outputs and also self.model(input) to get logits.
+
+ Args:
+ data: The data for inference.
+ """
+
+    @abstractmethod
+ def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ """
+ Generate results given a list of inputs.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: The maximum length of the output.
+
+ Returns:
+ A list of generated strings.
+ """
+
+    @abstractmethod
+ def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
+ """
+ Get loss given batch and batch with target.
+ Use their length difference after tokenization to mask the loss and only compute loss at target tokens.
+
+ Args:
+ batch: batch prompt without target answer.
+ batch_target: batch prompt with target answer.
+
+ Returns:
+ A list of loss.
+ """
+
+ def to(self, device):
+ self.model.to(device)
diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py
new file mode 100644
index 000000000000..f293c4f699cd
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/chatglm.py
@@ -0,0 +1,303 @@
+import copy
+from typing import List
+
+import torch
+
+from .huggingface import HuggingFaceModel
+
+IGNORE_INDEX = -100
+
+
+class ChatGLMModel(HuggingFaceModel):
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ truncated_inputs = copy.deepcopy(inputs)
+ # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
+ for i, input in enumerate(inputs):
+ a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False)
+
+ if len(a_ids) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
+ a_ids[-half:], skip_special_tokens=True
+ )
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
+
+ @torch.no_grad()
+ def get_loss(
+ self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+ ) -> List[List[float]]:
+ """
+ Calculate loss only on target tokens.
+
+ Args:
+ batch: A batch of prompt without target answer.
+ batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
+
+ Returns:
+ Loss.
+
+ """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+ # Target answer's length is usually << model_max_length, but we still call it in case.
+ # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ labels_list = []
+ input_ids_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
+ # If there is no history, the prompt is just the query.
+ # We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B.
+ # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276
+ target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False)
+
+ # Get prompt with length model_max_length - len(target_tokenized).
+ # Reserve some space for target answer tokens using max_new_tokens.
+ # This will generate the correct start_idx and end_idx.
+ max_new_tokens = len(target_tokenized)
+
+ # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens.
+ # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323
+ prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0]
+ input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False)
+
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized)
+
+ context_length = input_ids.index(self.tokenizer.bos_token_id)
+ context_length - 1
+
+ target_ids = [IGNORE_INDEX] * len(input_ids)
+
+ # -1 is for eos_token, we don't want to calculate loss on eos token.
+ target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
+
+ input_ids_list.append(torch.LongTensor(input_ids))
+ labels_list.append(torch.LongTensor(target_ids))
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+ start_indice = 0
+ losses_per_sample = []
+
+ target_token_nums_per_sample = []
+ for length in batch_target_nums:
+ losses_per_sample.append(losses[start_indice : start_indice + length])
+ target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
+ start_indice += length
+
+ return losses_per_sample, target_token_nums_per_sample, None
+
+ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> List[float]:
+ """
+ Calculate loss only on target tokens.
+ Hugging Face generate() function can't return per sample loss.
+ It will only return the mean of the loss in a batch.
+ In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.
+
+ Args:
+ input_ids_list: A batch of input token ids.
+ labels: A batch of labels.
+
+ Returns:
+ A list of loss.
+
+ """
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ ).to(torch.cuda.current_device())
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
+ torch.cuda.current_device()
+ )
+
+ outputs = self.model(input_ids)[0]
+
+ shift_logits = outputs[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
+
+ lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+
+ loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
+ return loss_sum.tolist(), lens.tolist()
+
+
+class ChatGLM2Model(ChatGLMModel):
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ truncated_inputs = copy.deepcopy(inputs)
+ # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
+ for i, input in enumerate(inputs):
+ a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False)
+
+ if len(a_ids) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
+ a_ids[-half:], skip_special_tokens=True
+ )
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
+ """Generate results given a list of inputs and get logits of the first new token over choices.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: Max new tokens for generation.
+ kwargs: Key arguments for generation
+
+ Returns:
+ A list of generated strings and logits over choices.
+
+ Note:
+ Currently the function only returns the logits of the first new token.
+            It is used for single-choice questions.
+            For multiple-choice questions, avoid using the loss over choices:
+            set the argument choices to None in self.inference().
+
+ """
+ # Follow the process of model.chat() method in modeling_chatglm2.py
+ # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020
+ # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001
+
+ query = []
+ for input in inputs:
+ prompt = self.tokenizer.build_prompt(input, None)
+ query.append(prompt)
+
+ truncated_query = self._get_truncated_prompts(query, max_new_tokens)
+
+ encoded_inputs = self.tokenizer(
+ truncated_query,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=self.model_max_length - max_new_tokens,
+ ).to(torch.cuda.current_device())
+
+ # Set output_scores=True to get prediction scores.
+ outputs = self.model.generate(
+ **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+ )
+
+ # We only need to decode predicted tokens.
+ sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
+
+ scores = []
+ if self.indices_for_choices:
+ # If the question is a single-choice question, we will return the scores of specific indices for first predicted token.
+ # The indices are the tokenization results of the options for the single-choice question.
+ # For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D.
+ for option_indices in self.indices_for_choices:
+ scores.append(outputs.scores[0][:, option_indices].detach().cpu())
+
+ scores = torch.max(torch.stack(scores), dim=0)[0]
+
+ decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+ return decoded_sequences, scores
+
+ @torch.no_grad()
+ def get_loss(
+ self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+ ) -> List[List[float]]:
+ """
+ Calculate loss only on target tokens.
+
+ Args:
+ batch: A batch of prompt without target answer.
+ batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
+
+ Returns:
+ Loss.
+
+ """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+ # Target answer's length is usually << model_max_length, but we still call it in case.
+ # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ labels_list = []
+ input_ids_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
+ prompt = self.tokenizer.build_prompt(input, None)
+
+ target_tokenized = self.tokenizer.encode(
+ text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length
+ )
+
+ max_new_tokens = len(target_tokenized)
+ prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
+ input_tokenized = self.tokenizer.encode(
+ prompt_with_correct_length,
+ add_special_tokens=True,
+ truncation=True,
+ max_length=self.model_max_length,
+ )
+
+ input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id]
+ target_ids = [IGNORE_INDEX] * len(input_ids)
+
+ # -1 is for "eos"
+ target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
+
+ input_ids_list.append(torch.LongTensor(input_ids))
+ labels_list.append(torch.LongTensor(target_ids))
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+        start_index = 0
+        losses_per_sample = []
+
+        target_token_nums_per_sample = []
+        for length in batch_target_nums:
+            losses_per_sample.append(losses[start_index : start_index + length])
+            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])
+            start_index += length
+
+ return losses_per_sample, target_token_nums_per_sample, None
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
new file mode 100644
index 000000000000..9f785a6aa9d1
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -0,0 +1,561 @@
+import copy
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
+from peft import PeftModel
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseModel
+
+IGNORE_INDEX = -100
+
+
+class HuggingFaceModel(BaseModel):
+ """
+ Model wrapper around HuggingFace AutoModel models.
+
+ Args:
+ path: The path to a HuggingFace model.
+ model_max_length: The maximum sequence length of the model.
+ tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+        peft_path: The name or path of a HuggingFace PEFT model.
+ model_kwargs: Keyword arguments for the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+
+ """
+
+ def __init__(
+ self,
+ path: str,
+ model_max_length: int = 2048,
+ tokenizer_path: Optional[str] = None,
+ tokenizer_kwargs: dict = dict(),
+ peft_path: Optional[str] = None,
+ model_kwargs: Dict = None,
+ prompt_template: Conversation = None,
+ batch_size: int = 1,
+ logger: DistributedLogger = None,
+ ):
+ super().__init__(
+ path=path,
+ model_max_length=model_max_length,
+ prompt_template=prompt_template,
+ batch_size=batch_size,
+ logger=logger,
+ )
+ self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
+
+ self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
+
+ def _get_choices_indices(self, language: str):
+ """
+ Get indices for each choice
+
+        Some tokenizers, such as Llama-2's, insert a BOS token unless add_special_tokens=False is specified.
+        The indices for the choices may also differ with context. For the Llama-2 tokenizer, in a Chinese context
+        like "答案:{choice}", the indices for choices A, B, C and D are 29909, 29933, 29907 and 29928; in an English
+        context like "Answer: {choice}", they are 319, 350, 315 and 360.
+        Run print(self.tokenizer("答案:A")) or print(self.tokenizer("Answer: A")) to inspect the tokenization.
+
+ """
+
+        # A trick to collect all token ids related to the given choices, one list per answer context.
+ self.indices_for_choices = [[] for _ in range(2)]
+ for choice in self.choices:
+ self.indices_for_choices[0].append(
+ self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
+ )
+ self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
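+        # Illustrative result for choices ["A", "B", "C", "D"] with the Llama-2 tokenizer, using the
+        # values quoted in the docstring above: indices_for_choices would be
+        # [[319, 350, 315, 360], [29909, 29933, 29907, 29928]] -- one list per answer context.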
+
+ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
+ """
+ Load tokenizer.
+
+ Args:
+ path: The path to the model. Usually it also serves as the path to the tokenizer.
+            tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+
+ """
+
+ if self.batch_size > 1:
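+            # Left padding/truncation keeps every prompt flush against the position where generation
+            # starts, which decoder-only models need for correct batched generation.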
+ tokenizer_kwargs.update({"padding_side": "left"})
+ tokenizer_kwargs.update({"truncation_side": "left"})
+
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)
+
+ if self.tokenizer.pad_token_id is None:
+            self.logger.warning("pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.")
+ if self.tokenizer.eos_token:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+            elif getattr(self.tokenizer, "eod_id", None):
+                # Qwen's tokenizer has no eos_token but provides an eod token "<|endoftext|>";
+                # getattr guards against tokenizers that define neither attribute.
+                self.tokenizer.pad_token_id = self.tokenizer.eod_id
+
+ def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+ """
+ Load model.
+
+ Args:
+ path: The path to the model.
+ model_kwargs: Keyword arguments for the model.
+ peft_path: The path to the peft model.
+
+ """
+
+        model_kwargs = model_kwargs or {}
+        if "torch_dtype" in model_kwargs:
+            # torch_dtype is passed as a string such as "torch.float16"; evaluate it to the actual dtype object.
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+
+        model_kwargs.setdefault("torch_dtype", torch.float16)
+
+ self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ if peft_path is not None:
+ self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+ self.model.eval()
+
+ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
+ """
+ Calculate loss only on target tokens.
+        Hugging Face's generate() function can't return per-sample loss; it only returns the mean loss over a batch.
+        In torch.nn.CrossEntropyLoss(), reduction is therefore set to "none" to obtain per-sample loss.
+
+ Args:
+ input_ids_list: A batch of input token ids.
+ labels: A batch of labels.
+
+ Returns:
+            A list of per-sample loss sums and a list of target token counts.
+
+ """
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ ).to(torch.cuda.current_device())
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
+ torch.cuda.current_device()
+ )
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
+
+ outputs = self.model(input_ids, attention_mask=attention_mask)[0]
+
+ shift_logits = outputs[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
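+        # Alignment sketch (hypothetical ids): for input_ids [p0, p1, t0, t1, eos] and labels
+        # [-100, -100, t0, t1, eos], the shift pairs the logits at positions 1..3 with the labels
+        # t0, t1 and eos, so the loss below is summed over exactly the target tokens.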
+
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
+
+ lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+
+ loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
+ return loss_sum.tolist(), lens.tolist()
+
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ """
+        Truncate the input sequence to fit model_max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions).
+ https://github.com/THUDM/LongBench/blob/main/pred.py#L16
+
+ Args:
+ inputs: A batch of input prompts.
+ max_new_tokens: Max new tokens for model to generate.
+
+ Returns:
+ Truncated prompts.
+
+ """
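+        # Example: with model_max_length = 8 and max_new_tokens = 2, a 10-token prompt keeps its
+        # first 3 and last 3 tokens, dropping the middle while preserving instructions at both ends.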
+
+ truncated_inputs = copy.deepcopy(inputs)
+ for i, input in enumerate(inputs):
+ tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0]
+ if len(tokenized_prompt) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(
+ tokenized_prompt[:half], skip_special_tokens=True
+ ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
+
+ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]:
+ """
+ Get input_ids and labels for pretrain data.
+        We only need batch_prompt because for pretrain datasets we don't need to predict new tokens.
+
+        Args:
+            batch_prompt: A batch of prompts.
+
+        Returns:
+            Input_ids, labels and byte counts for the given batch.
+
+ """
+ input_ids_list = []
+ labels_list = []
+ bytes_list = []
+
+ for input in batch_prompt:
+            # Pretrain data tends to be very long, often much longer than model_max_length, so we first tokenize only
+            # 1/ratio of the data to accelerate tokenization.
+            # Once the tokenized length is greater than or equal to model_max_length, we stop iterating over ratios and
+            # use the result as input_ids and labels; the rest of the original string never needed to be tokenized.
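+            # Example: with model_max_length = 2048, tokenizing the first 1/16 of a long document often
+            # already yields >= 2048 tokens, so the remaining ratios (8, 4, 2, 1) are skipped.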
+ ratio = [16, 8, 4, 2, 1]
+ tokenized = None
+ for r in ratio:
+ tokenized = self.tokenizer(
+ [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt"
+ )
+ if tokenized.input_ids.size(1) >= self.model_max_length:
+ break
+
+ input_ids = copy.deepcopy(tokenized["input_ids"])[0]
+ target_ids = copy.deepcopy(input_ids)
+
+ string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
+
+ bytes_list.append(len(string.encode("utf-8")))
+
+ input_ids_list.append(input_ids)
+ labels_list.append(target_ids)
+
+ return input_ids_list, labels_list, bytes_list
+
+ def _get_input_ids_and_labels(
+ self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+ ) -> Tuple[List[torch.LongTensor]]:
+ """
+ Get input_ids and labels for the given data.
+
+ Args:
+            batch_prompt: A batch of prompts.
+            batch_target: A batch of targets.
+
+ Returns:
+ Input_ids and labels for the given batch.
+
+ """
+ if pretrain:
+ return self._get_input_ids_and_labels_pretrain(batch_prompt)
+
+ input_ids_list = []
+ labels_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # TODO: Improve the labeling process. Should annotate the border by adding special tokens.
+ target_tokenized = self.tokenizer(
+ [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
+ )
+
+ # Get prompt with length model_max_length - len(target_tokenized).
+ # Reserve some space for target answer tokens using max_new_tokens.
+ # This will generate the correct start_idx and end_idx.
+ max_new_tokens = target_tokenized["input_ids"][0].size(0)
+ prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0]
+ input_tokenized = self.tokenizer(
+ [prompt_with_correct_length],
+ truncation=True,
+ max_length=self.model_max_length - max_new_tokens,
+ return_tensors="pt",
+ )
+
+ target_tokenized = self.tokenizer(
+ [prompt_with_correct_length + target],
+ truncation=True,
+ max_length=self.model_max_length,
+ return_tensors="pt",
+ )
+
+ start_idx = input_tokenized["input_ids"][0].size(0)
+ end_idx = target_tokenized["input_ids"][0].size(0)
+
+                # Sometimes when the target is only an option such as A, B, C or D, the length of input_tokenized
+                # equals the length of target_tokenized, so we need to subtract 1 from start_idx.
+                # This is caused by differing tokenizer behavior: in a plain-prompt setting, the Baichuan and Llama
+                # tokenizers produce sequences of the same length for "Answer: " and "Answer: A"
+                # (and likewise for "prompt" and "prompt + A").
+                # Baichuan: [29394, 31143, 31106] vs. [29394, 31143, 703]
+                # Llama:    [673, 29901, 29871] vs. [673, 29901, 319]
+                # For ChatGLM, the tokenized sequences do differ in length:
+                # ChatGLM:  [16583, 12] vs. [16583, 12, 167]
+
+ if start_idx == end_idx:
+ start_idx -= 1
+
+ input_ids = copy.deepcopy(target_tokenized["input_ids"])[0]
+ target_ids = copy.deepcopy(input_ids)
+
+ mask = torch.zeros_like(target_ids, dtype=torch.bool)
+ mask[start_idx:end_idx] = True
+
+ target_ids[~mask] = IGNORE_INDEX
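+                # Example: 3 prompt tokens plus a 1-token target typically give start_idx = 3 and end_idx = 4,
+                # so only the final (target) position keeps its label and contributes to the loss.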
+
+ input_ids_list.append(input_ids)
+ labels_list.append(target_ids)
+
+ return input_ids_list, labels_list, None
+
+ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
+ """
+ Infer the given data.
+ This function will call self.generate() to get model outputs and also self.model() to get logits.
+
+ Args:
+ data: The data for inference.
+ inference_kwargs: Arguments for inference.
+ debug: Whether to display generated prompt for debugging.
+
+ Returns:
+ Inference results.
+
+ """
+ calculate_loss = inference_kwargs["calculate_loss"]
+ classes = inference_kwargs["all_classes"]
+ language = inference_kwargs["language"]
+ pretrain = inference_kwargs["pretrain"]
+ max_new_tokens = inference_kwargs["max_new_tokens"]
+ few_shot_data = inference_kwargs.get("few_shot_data", None)
+
+        # Some classification questions have options that are full texts rather than a single letter such as A, B, C or D.
+        # If any option is longer than one character, we won't calculate loss over choices.
+ if classes is not None and any(len(c) > 1 for c in classes):
+ classes = None
+
+ self.choices = classes
+ self.indices_for_choices = None
+ if self.choices:
+ # Get indices for each choice
+ self._get_choices_indices(language)
+
+ self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
+
+ bar = tqdm(
+ range(math.ceil(len(data) / self.batch_size)),
+ desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
+ disable=not is_rank_0(),
+ )
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+ answers = copy.deepcopy(data)
+ for i in range(0, len(data), self.batch_size):
+ batch = data[i : i + self.batch_size]
+ batch_prompt, batch_target = get_batch_prompt(
+ self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
+ )
+
+ if is_rank_0() and debug and i == 0:
+ self.logger.info(
+                    f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} are:\n{inference_kwargs}"
+ )
+ self.logger.info("-" * 120)
+ self.logger.info("An example prompt and prompt with target is:")
+ self.logger.info("-" * 120)
+ self.logger.info(batch_prompt[0])
+ self.logger.info("-" * 120)
+ self.logger.info(batch_prompt[0] + batch_target[0][0])
+
+ if not pretrain:
+ batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
+
+ if calculate_loss:
+ batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
+ batch_prompt, batch_target, pretrain
+ )
+
+ probs = []
+ if self.indices_for_choices:
+ scores = scores.to(torch.float32)
+            # If we have indices_for_choices (the question must be single-choice), there is only one target answer
+            # per data sample; anything else would violate the single-choice setting.
+
+ if calculate_loss:
+ labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
+
+ loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
+
+ probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
+ probs = [
+ {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
+ ]
+
+ for j in range(len(batch_prompt)):
+ if not pretrain:
+ answers[i + j]["output"] = batch_decodes[j].strip()
+
+ if isinstance(scores, torch.Tensor):
+ answers[i + j]["softmax_over_choices"] = probs[j]
+
+ if calculate_loss:
+ answers[i + j]["loss_over_choices"] = loss_over_choices[j]
+
+ if calculate_loss:
+ answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
+
+                    # loss_sum is used specifically for pretrain datasets to calculate per-byte perplexity.
+                    # However, loss (the per-sample loss) suffices for most cases.
+ answers[i + j]["loss_sum"] = batch_losses[j]
+ answers[i + j]["token_num"] = batch_target_token_nums[j]
+
+ if batch_bytes_nums:
+ answers[i + j]["byte_num"] = batch_bytes_nums[j]
+
+ bar.update()
+
+ return answers
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
+ """Generate results given a list of inputs and get logits of the first new token over choices.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: Max new tokens for generation.
+            kwargs: Keyword arguments for generation.
+
+ Returns:
+ A list of generated strings and logits over choices.
+
+ Note:
+            Currently the function only returns the logits of the first new token.
+            It is intended for single-choice questions.
+            For multiple-choice questions, please avoid using the loss over choices;
+            set the choices argument to None in self.inference().
+
+ """
+ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
+
+ encoded_inputs = self.tokenizer(
+ truncated_inputs,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ return_token_type_ids=False,
+ max_length=self.model_max_length - max_new_tokens,
+ ).to(torch.cuda.current_device())
+
+ # Set output_scores=True to get prediction scores.
+ outputs = self.model.generate(
+ **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+ )
+
+ # We only need to decode predicted tokens.
+ sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
+
+ scores = []
+ if self.indices_for_choices:
+            # If the question is a single-choice question, we return the scores at specific indices for the first predicted token.
+            # The indices are the tokenization results of the options for the single-choice question.
+            # For example, if the options of the question are A, B, C and D, we only return scores at the indices of A, B, C and D.
+ for option_indices in self.indices_for_choices:
+ scores.append(outputs.scores[0][:, option_indices].detach().cpu())
+
+ scores = torch.max(torch.stack(scores), dim=0)[0]
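+            # outputs.scores is a tuple with one (batch_size, vocab_size) tensor per generated token;
+            # stacking the per-context slices and taking the elementwise max keeps, for each choice,
+            # the stronger score between the English and Chinese answer contexts.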
+
+ decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+ return decoded_sequences, scores
+
+ @torch.no_grad()
+ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+ """
+ Calculate loss only on target tokens.
+
+ Args:
+            batch_prompt: A batch of prompts without target answers.
+            batch_target: A batch of target answers. One question can sometimes have multiple target answers.
+
+        Returns:
+            Per-sample losses, target token counts and, for pretrain data, byte counts.
+
+ """
+
+        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss;
+        # no new tokens have to be generated.
+        # A target answer's length is usually much smaller than model_max_length, but we still truncate as a safeguard.
+        # We don't call self._get_truncated_prompts on batch_prompt here because we first need each target answer's
+        # length in order to reserve space for its tokens.
+ if not pretrain:
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+        start_index = 0
+        losses_per_sample = []
+
+        target_token_nums_per_sample = []
+        bytes_nums_per_sample = []
+        for length in batch_target_nums:
+            losses_per_sample.append(losses[start_index : start_index + length])
+            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])
+
+            if bytes_list:
+                bytes_nums_per_sample.append(bytes_list[start_index : start_index + length])
+
+            start_index += length
+
+ if bytes_list:
+ return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
+
+ return losses_per_sample, target_token_nums_per_sample, None
+
+
+class HuggingFaceCausalLM(HuggingFaceModel):
+ """
+ Model wrapper around HuggingFace AutoModelForCausalLM models.
+
+ Args:
+ path: The path to a HuggingFace model.
+ model_max_length: The maximum sequence length of the model.
+ tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+        peft_path: The name or path of a HuggingFace PEFT model.
+ model_kwargs: Keyword arguments for the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+
+ """
+
+ def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+ """
+ Load model.
+
+ Args:
+ path: The path to the model.
+ model_kwargs: Keyword arguments for the model.
+ peft_path: The path to the peft model.
+
+ """
+
+        model_kwargs = model_kwargs or {}
+        if "torch_dtype" in model_kwargs:
+            # torch_dtype is passed as a string such as "torch.float16"; evaluate it to the actual dtype object.
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+
+        if "config" in model_kwargs:
+            model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
+
+        model_kwargs.setdefault("torch_dtype", torch.float16)
+ self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ if peft_path is not None:
+ self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+ self.model.eval()
diff --git a/applications/ColossalEval/colossal_eval/utils/__init__.py b/applications/ColossalEval/colossal_eval/utils/__init__.py
new file mode 100644
index 000000000000..d5ee6e13b747
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/__init__.py
@@ -0,0 +1,4 @@
+from .conversation import Conversation, get_batch_prompt, prompt_templates
+from .utilities import get_json_list, is_rank_0, jdump, jload
+
+__all__ = ["Conversation", "prompt_templates", "get_batch_prompt", "is_rank_0", "jload", "jdump", "get_json_list"]
diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py
new file mode 100644
index 000000000000..6c096a8523c0
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/conversation.py
@@ -0,0 +1,231 @@
+import dataclasses
+from enum import Enum, auto
+from typing import Dict, List, Optional, Tuple
+
+from transformers import AutoTokenizer
+
+
+class SeparatorStyle(Enum):
+ ADD_BOS_EOS_TOKEN = auto()
+ ALPACA = auto()
+ PLAIN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_BOS_EOS_TOKEN
+ sep: str = ""
+
+ def clear(self):
+ self.messages = []
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + "" + message + self.sep
+ else:
+ ret += role + ": " + ""
+ return ret
+ elif self.sep_style == SeparatorStyle.ALPACA:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ":\n" + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += message
+ else:
+ ret += ""
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def get_prompt_with_target(self, target):
+ prompt = self.get_prompt()
+ prompt_with_target = []
+
+        # Some datasets provide multiple target answers, which complicates loss calculation.
+        # We therefore convert target into List[str] first when the question has only one target answer.
+ target_answers = []
+ if isinstance(target, str):
+ target_answers = [target]
+ else:
+ target_answers = target
+
+        for target_answer in target_answers:
+            if self.sep_style in (SeparatorStyle.ADD_BOS_EOS_TOKEN, SeparatorStyle.ALPACA, SeparatorStyle.PLAIN):
+                # All current separator styles simply append the target answer to the prompt.
+                prompt_with_target.append(prompt + target_answer)
+            else:
+                raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return prompt_with_target
+
+ def save_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + "" + message + "\n"
+ else:
+ ret += role + ": " + ""
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep_style": self.sep_style,
+ "sep": self.sep,
+ }
+
+
+def get_few_shot_prefix(
+ conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
+) -> str:
+    """
+    Get the few-shot prompt prefix.
+
+    Args:
+        conv: Conversation template.
+        few_shot_data: Few-shot examples used to build the prompt prefix.
+        tokenizer: Tokenizer used to measure the prefix length.
+        language: Language of the prefix ("English" or "Chinese").
+        max_tokens: Maximum number of tokens the prefix may occupy.
+
+    Returns:
+        Few-shot prompt prefix.
+    """
+
+    if language == "English":
+        few_shot_prefix = "The following are answers for questions in an exam.\n\n"
+    elif language == "Chinese":
+        few_shot_prefix = "以下是考试中各个问题的答案。\n\n"
+    else:
+        # Guard against unsupported languages; otherwise few_shot_prefix would be undefined below.
+        raise ValueError(f"Unsupported language for few-shot prefix: {language}")
+
+ output = None
+ for i in range(len(few_shot_data)):
+ few_shot_prefix = few_shot_prefix + few_shot_data[i] + "\n\n"
+
+ if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens:
+ output = few_shot_prefix
+ else:
+ break
+
+ return output if output is not None else few_shot_prefix
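+    # Example (English): with two short examples and a generous max_tokens, the returned prefix is
+    # "The following are answers for questions in an exam.\n\n<example 1>\n\n<example 2>\n\n";
+    # the first example whose addition would exceed max_tokens is dropped along with all later ones.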
+
+
+def get_batch_prompt(
+ conv: Conversation,
+ batch: List[Dict],
+ few_shot_data: List[str],
+ tokenizer: Optional[AutoTokenizer],
+ language: Optional[str],
+ model_max_length: Optional[int],
+) -> Tuple[List[Dict], List[Dict]]:
+ """
+ Get batch prompt and target.
+
+    Args:
+        conv: Conversation template.
+        batch: Batch data to generate prompts from.
+        few_shot_data: Few-shot data used to build the few-shot prompt prefix.
+        tokenizer: Tokenizer used to measure prompt lengths.
+        language: Language of the few-shot prefix.
+        model_max_length: The maximum sequence length of the model.
+
+    Returns:
+        Tuple containing the batch prompts and targets.
+
+ """
+
+ batch_prompt = []
+ batch_target = []
+
+ if isinstance(batch[0], dict):
+ for b in batch:
+ few_shot_prefix = ""
+ if few_shot_data is not None:
+                # For few-shot, only the input is needed; otherwise, use the instruction (as in AGIEval).
+ query_text = b["input"] if b.get("input", "") != "" else b["instruction"]
+
+ if isinstance(b["target"], str):
+ zero_shot_prompt = query_text + b["target"]
+ max_tokens = model_max_length - len(tokenizer([zero_shot_prompt]).input_ids[0])
+ else:
+ raise Exception("When using few-shot, target answer should be a string.")
+
+ few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
+ else:
+ query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
+
+ conv.append_message(conv.roles[0], few_shot_prefix + query_text)
+ conv.append_message(conv.roles[1], None)
+
+ batch_prompt.append(conv.get_prompt())
+
+ target = b["target"]
+ if isinstance(b["target"], str):
+ target = [target]
+
+ batch_target.append(target)
+
+ conv.clear()
+
+ return batch_prompt, batch_target
+
+
+conv_coati = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+ sep="",
+)
+
+conv_alpaca = Conversation(
+ system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+ roles=("### Instruction", "### Response"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.ALPACA,
+ sep="\n\n",
+)
+
+conv_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="",
+)
+
+prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain}
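+# Usage sketch (hypothetical conversation): copy a template before mutating it so the shared
+# instance in prompt_templates keeps an empty message list.
+#
+#     conv = prompt_templates["coati"].copy()
+#     conv.append_message(conv.roles[0], "What is the capital of France?")
+#     conv.append_message(conv.roles[1], None)
+#     conv.get_prompt()  # -> system text + "Human: What is the capital of France?Assistant: "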
diff --git a/applications/ColossalEval/colossal_eval/utils/utilities.py b/applications/ColossalEval/colossal_eval/utils/utilities.py
new file mode 100644
index 000000000000..4eda07907495
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/utilities.py
@@ -0,0 +1,62 @@
+import io
+import json
+import os
+
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def _make_w_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f_dirname = os.path.dirname(f)
+ if f_dirname != "":
+ os.makedirs(f_dirname, exist_ok=True)
+ f = open(f, mode=mode, encoding="utf-8")
+ return f
+
+
+def _make_r_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f = open(f, mode=mode, encoding="utf-8")
+ return f
+
+
+def jdump(obj, f, mode="w", indent=4, default=str):
+ """
+    Dump a string, dictionary, or list to a file in JSON format.
+
+ Args:
+ obj: An object to be written.
+ f: A string path to the location on disk.
+ mode: Mode for opening the file.
+ indent: Indent for storing json dictionaries.
+ default: A function to handle non-serializable entries; defaults to `str`.
+
+ """
+ f = _make_w_io_base(f, mode)
+ if isinstance(obj, (dict, list)):
+ json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
+ elif isinstance(obj, str):
+ f.write(obj)
+ else:
+ raise ValueError(f"Unexpected type: {type(obj)}")
+ f.close()
+
+
+def jload(f, mode="r"):
+ """Load a .json file into a dictionary."""
+ f = _make_r_io_base(f, mode)
+ jdict = json.load(f)
+ f.close()
+ return jdict
+
+
+def get_json_list(file_path):
+ with open(file_path, "r") as f:
+ json_list = []
+ for line in f:
+            json_list.append(json.loads(line))
+ return json_list
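+# Usage sketch (hypothetical path): jdump(results, "outputs/answers.json") creates the parent
+# directory if needed and writes UTF-8 JSON with non-ASCII text preserved (ensure_ascii=False);
+# jload("outputs/answers.json") reads it back, and get_json_list handles JSON-lines files.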
diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json
new file mode 100644
index 000000000000..d7c864881008
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json
@@ -0,0 +1,44 @@
+{
+ "language": "cn",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json
new file mode 100644
index 000000000000..6ebe3996b1cf
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json
@@ -0,0 +1,44 @@
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json
new file mode 100644
index 000000000000..f869830555b4
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json
@@ -0,0 +1,202 @@
+[
+ {
+ "category": "brainstorming",
+ "instruction": "列举一些可以促进头发生长的食物。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "brainstorming",
+        "instruction": "中年夫妻如何提升夫妻感情,请给出三个实用的方法,并举例说明。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 2
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请列举4种日常的环保行为。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 3
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请给出5个可以随时随地锻炼身体的小动作。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 4
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 5
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。",
+ "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:",
+ "output": "",
+ "target": "",
+ "id": 6
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。李华是一名参加了期末考试的学生,他已经很担心自己的考试成绩。老师Lucy正在帮助他度过这个紧张的时刻。",
+ "input": "李华:Lucy老师,我很担心自己的考试成绩,我不知道我是否能够通过这次考试。 Lucy:放松,李华,你已经做好了充分的准备。相信你自己,你会做得很好的。 李华:我很怕考试时会忘记自己所学的知识。 Lucy:你可以预留一些时间,过一遍自己所学的知识点或笔记,这样你会更有信心和准确地回答考题。 李华:如果我还是失败了,该怎么办? Lucy:",
+ "output": "",
+ "target": "",
+ "id": 7
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。张先生是一名企业家,正在考虑是否开拓海外市场;李女士是一名跨境电商专家,擅长国际商务和电子商务。",
+ "input": "张先生:你好,李女士,我正在考虑将我们的产品销售扩大至海外市场,您有什么建议吗? 李女士:您好,张先生,我们需要考虑到海外市场对于产品的需求是否与国内市场一致,需要进行市场调研和定位。然后再进行各种软性、硬性的创新。 张先生:听起来很专业,您能具体解释一下吗? 李女士:",
+ "output": "",
+ "target": "",
+ "id": 8
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小明是一名医生。一名病患想要提前停药。小王是病患的儿子,希望父亲能够听取医生的建议。",
+ "input": "小明:你好,小王,我了解你想要让你父亲停药。小王:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。小明:",
+ "output": "",
+ "target": "",
+ "id": 9
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。张三是一位语文老师,对学生认真负责;李四是张三的学生,对语文兴趣不是很高。",
+ "input": "张三:同学们,今天要讲的是一篇古文《岳阳楼记》。这篇文章非常精彩,希望同学们能够认真听课,理解其中的含义。 李四:怎么又是古文? 张三:",
+ "output": "",
+ "target": "",
+ "id": 10
+ },
+ {
+ "category": "generation",
+ "instruction": "根据主题写一封邮件。",
+ "input": "主题: \"加入我们,共创未来\"",
+ "output": "",
+ "target": "",
+ "id": 11
+ },
+ {
+ "category": "generation",
+ "instruction": "为公司编写一份职场行为准则,包括明确的行为规范和道德准则。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 12
+ },
+ {
+ "category": "generation",
+ "instruction": "请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 13
+ },
+ {
+ "category": "generation",
+ "instruction": "请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 14
+ },
+ {
+ "category": "generation",
+ "instruction": "根据以下故事提示写一篇故事:",
+ "input": "故事提示:```在一个废弃的古堡中,一个小女孩遇到了一只会说话的黑猫,他们一起揭开了一个古老的谜题。```",
+ "output": "",
+ "target": "",
+ "id": 15
+ },
+ {
+ "category": "open_qa",
+ "instruction": "请介绍一下《红楼梦》这部经典小说的故事情节。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 16
+ },
+ {
+ "category": "open_qa",
+ "instruction": "解释什么是RNA病毒和DNA病毒。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 17
+ },
+ {
+ "category": "open_qa",
+ "instruction": "什么是比特币?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 18
+ },
+ {
+ "category": "open_qa",
+ "instruction": "在计算机中,什么是RAM?与ROM有什么区别?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 19
+ },
+ {
+ "category": "open_qa",
+ "instruction": "请简单介绍一下世界上最长的河流途经的国家。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 20
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}”\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 21
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你假定自己是雅思写作考官,根据雅思评判标准,按我给你的雅思考题和对应答案给我评分,并且按照雅思写作评分细则给出打分依据。此外,请给我详细的修改意见并写出满分范文。第一个问题是:It is sometimes argued that too many students go to university, while others claim that a university education should be a universal right. Discuss both sides of the argument and give your own opinion.对于这个问题,我的答案是:In some advanced countries, it is not unusual for more than 50% of young adults to attend college or university. Critics, however, claim that many university courses are worthless and young people would be better off gaining skills in the workplace. In this essay, I will examine both sides of this argument and try to reach a conclusion.There are several reasons why young people today believe they have the right to a university education. First, growing prosperity in many parts of the world has increased the number of families with money to invest in their children’s future. At the same time, falling birthrates mean that one- or two-child families have become common, increasing the level of investment in each child. It is hardly surprising, therefore, that young people are willing to let their families support them until the age of 21 or 22. Furthermore, millions of new jobs have been created in knowledge industries, and these jobs are typically open only to university graduates.However, it often appears that graduates end up in occupations unrelated to their university studies. It is not uncommon for an English literature major to end up working in sales, or an engineering graduate to retrain as a teacher, for example. Some critics have suggested that young people are just delaying their entry into the workplace, rather than developing professional skills.请依次给到我以下内容:具体分数及其评分依据、文章修改意见、满分范文。\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 22
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 23
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你充当宠物行为主义者。我将为您提供一只宠物和它们的主人,您的目标是帮助主人了解为什么他们的宠物表现出某些行为,并提出帮助宠物做出相应调整的策略。您应该利用您的动物心理学知识和行为矫正技术来制定一个有效的计划,双方的主人都可以遵循,以取得积极的成果。我的第一个请求是“我有一只好斗的德国牧羊犬,它需要帮助来控制它的攻击性。”\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 24
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你充当正则表达式生成器。您的角色是生成匹配文本中特定模式的正则表达式。您应该以一种可以轻松复制并粘贴到支持正则表达式的文本编辑器或编程语言中的格式提供正则表达式。不要写正则表达式如何工作的解释或例子;只需提供正则表达式本身。我的第一个提示是生成一个匹配电子邮件地址的正则表达式。\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 25
+ }
+]
diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json
new file mode 100644
index 000000000000..27b8af8bc4c6
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json
@@ -0,0 +1,202 @@
+[
+ {
+ "category": "brainstorming",
+ "instruction": "Which are some popular fiction books that I should read?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How do I properly store fruits and vegetables to keep them fresh for longer?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 2
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How do you properly chop an onion without crying?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 3
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How to make an international transfer? Please provide 3 techniques.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 4
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "Name five leadership qualities that you consider most important.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 5
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.",
+ "input": "Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice? Emma: Hi Alex, sure. What kind of writing are you doing? Alex: I'm trying to write a novel, but I just can't seem to find any inspiration. Emma: ",
+ "output": "",
+ "target": "",
+ "id": 6
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. John: An experienced software engineer with a passion for coding. Karen: A recent college graduate who is interested in learning more about software development.",
+ "input": "Karen: Hi John, I noticed that you have a lot of experience in the software industry. Can you tell me what you think is the most important skill for a software engineer? John: ",
+ "output": "",
+ "target": "",
+ "id": 7
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Sarah is a new employee who is nervous about her first presentation; Tom is her boss who has given her coaching and preparation materials.",
+ "input": "Sarah: Tom, I'm feeling really nervous about my presentation tomorrow. Tom: I know how you feel, Sarah. However, I believe in you and your abilities. Just stick to the preparation materials that I have given you, and you'll do great. Sarah: Thank you, Tom. What if I forget something important during the presentation? Tom: ",
+ "output": "",
+ "target": "",
+ "id": 8
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Sarah: a young artist who is full of creative ideas and always eager to try new things. Jack: a seasoned artist who has achieved great success in the art world and is more traditional in his approach to art.",
+ "input": "Sarah: Hi Jack, I'm really excited to meet you. I'm a big fan of your work. Jack: Hi Sarah, nice to meet you too. So, what kind of art do you do? Sarah: I am passionate about abstract art, especially combining different materials and colors. I think it can really give people a new perspective on things. Jack: That's interesting, but I am more focused on realistic paintings. I believe the most important thing is to master the basic skills first. Sarah: ",
+ "output": "",
+ "target": "",
+ "id": 9
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a conversation based on the following persona information. Sarah is a college student who is interested in joining a volunteer organization. John is the leader of the volunteer organization and is eager to welcome new members.",
+ "input": "Sarah: Hi, I'm Sarah, and I'm interested in joining your volunteer organization. John: Hi Sarah, welcome! We're always looking for new members who are passionate about volunteering. What areas would you like to focus on? Sarah: I'm interested in community outreach and working with children. John: ",
+ "output": "",
+ "target": "",
+ "id": 10
+ },
+ {
+ "category": "generation",
+ "instruction": "Write an email based on the subject:",
+ "input": "Subject: \"Invitation to an Exclusive Webinar\"",
+ "output": "",
+ "target": "",
+ "id": 11
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a set of guidelines for first-time pet owners on how to properly care for a new puppy.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 12
+ },
+ {
+ "category": "generation",
+ "instruction": "Can you help me write a persuasive speech on why we should recycle more and take better care of the environment?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 13
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a pitch for a brand-new mobile app that helps people organize their daily tasks more efficiently.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 14
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a social media post promoting an upcoming concert featuring a local band.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 15
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Describe the significance of the Renaissance period in European history.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 16
+ },
+ {
+ "category": "open_qa",
+ "instruction": "What is the term for the surgical removal of the appendix?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 17
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Explain the process of osmosis in biological systems.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 18
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Who were the members of the Beatles band?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 19
+ },
+ {
+ "category": "open_qa",
+        "instruction": "Who painted The Scream?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 20
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. my first command is pwd",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 21
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a travel guide. I will write you my location and you will suggest a place to visit near my location. In some cases, I will also give you the type of places I will visit. You will also suggest me places of similar type that are close to my first location. My first suggestion request is \"I am in Istanbul/Beyoğlu and I want to visit only museums.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 22
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as an advertiser. You will create a campaign to promote a product or service of your choice. You will choose a target audience, develop key messages and slogans, select the media channels for promotion, and decide on any additional activities needed to reach your goals. My first suggestion request is \"I need help creating an advertising campaign for a new type of energy drink targeting young adults aged 18-30.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 23
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. My first request is \"I need an interesting story on perseverance.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 24
+ },
+ {
+ "category": "roleplay",
+        "instruction": "I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate to. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound every time! My first request is \"I need a rap song about finding strength within yourself.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 25
+ }
+]
diff --git a/applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json
similarity index 100%
rename from applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_cn.json
rename to applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json
diff --git a/applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json
similarity index 100%
rename from applications/Chat/evaluate/prompt/battle_prompt/battle_prompt_en.json
rename to applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json
diff --git a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json
similarity index 56%
rename from applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json
rename to applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json
index dccab2417eee..70f6c3ebc316 100644
--- a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_cn.json
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json
@@ -39,53 +39,8 @@
},
"prompt": "你是一个好助手。请你为下面的“补全对话”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
},
- "classification": {
- "id": 3,
- "category": "classification",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
- "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
- "correctness": "正确性(1-5):答案是否正确。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
- "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
- "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
- },
- "prompt": "你是一个好助手。请你为下面的“分类“问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- },
- "closed_qa": {
- "id": 4,
- "category": "closed_qa",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
- "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
- "correctness": "正确性(1-5):答案是否正确。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
- "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
- "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
- },
- "prompt": "你是一个好助手。请你为下面问题的答案打分。\n\n问题如下:\n\n{question}\n\n需要你评分的答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- },
- "extraction": {
- "id": 5,
- "category": "extraction",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
- "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
- "correctness": "准确性(1-5):回答应该准确无误地提取出所需信息,不应该包含任何错误或误导性信息。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
- "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
- "correctness": "1. 仔细阅读问题并确定需要从材料中提取的信息。\n2. 仔细阅读回答并确保它涵盖了所有需要提取的信息。\n3. 使用所提供的材料来验证回答的准确性。如果回答不准确或包含错误或误导性信息,则无法给出高分。\n4. 检查回答是否包含所有要求提取的信息,不要漏掉任何重要细节。\n5. 根据回答的准确性和完整性,给出一个介于1和5之间的分数,5分表示回答非常准确且完整,1分表示回答几乎没有提取出所需信息。\n\n准确性:"
- },
- "prompt": "你是一个好助手。请你为下面的“提取”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- },
"generation": {
- "id": 6,
+ "id": 3,
"category": "generation",
"metrics": {
"language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
@@ -100,7 +55,7 @@
"prompt": "你是一个好助手。请你为下面的“生成”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
},
"open_qa": {
- "id": 7,
+ "id": 4,
"category": "open_qa",
"metrics": {
"language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
@@ -114,23 +69,8 @@
},
"prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
},
- "rewriting": {
- "id": 8,
- "category": "rewriting",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
- "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
- "correctness": "正确性(1-5):答案是否正确。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
- "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
- "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
- },
- "prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- },
"roleplay": {
- "id": 9,
+ "id": 5,
"category": "roleplay",
"metrics": {
"language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
@@ -146,33 +86,14 @@
},
"prompt": "你是一个好助手。请你为下面的“角色扮演”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
},
- "summarization": {
- "id": 10,
- "category": "summarization",
- "metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
- "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
- "correctness": "准确性(1-5):回答应该准确无误地总结出材料的重点。",
- "conciseness": "简明扼要(1-5):答案是否简明扼要,没有冗余内容。"
- },
- "CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
- "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
- "correctness": "1. 仔细阅读问题给的材料,理解其内容和要点。\n2. 评估回答是否准确地总结出原始材料的重点。\n3. 评估回答是否包含原始材料中的所有关键信息。\n4. 根据以上步骤,给出一个1-5的分数,其中1表示回答不能准确地总结出材料的重点,5表示回答完全准确地总结出材料的重点。\n\n准确性:",
- "conciseness": "1. 阅读题目,提取出材料的重点。\n2. 阅读该总结,并注意其中的主要观点和信息。\n3. 评估总结的长度。一个简明扼要的总结通常应该在几句话或几段文字内传达关键信息,而不是冗长的段落或文章。\n4. 检查总结是否包含与主要观点无关的信息或冗余信息。\n5.确定总结涵盖了材料中的关键信息,并且没有忽略任何重要细节。\n6.给总结打出1-5的分数,其中5表示总结简明扼要,没有冗余内容,而1表示总结冗长或包含不必要的信息,难以理解或记忆。根据您的判断,打出适当的得分。\n\n简明扼要:"
- },
- "prompt": "你是一个好助手。请你为下面的“总结”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
- },
- "general": {
- "id": 11,
- "category": "general",
+ "Other": {
+ "id": 6,
+ "category": "Other",
"metrics": {
- "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
"relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
"correctness": "正确性(1-5):答案是否正确。"
},
"CoT": {
- "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
"relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
"correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
},
diff --git a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json
similarity index 59%
rename from applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json
rename to applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json
index 8355b0c27b79..3d04387d98c5 100644
--- a/applications/Chat/evaluate/prompt/evaluation_prompt/evaluation_prompt_en.json
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json
@@ -39,53 +39,8 @@
},
"prompt": "You are a good assistant. Please rate the given answer to the \"chat\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
},
- "classification": {
- "id": 3,
- "category": "classification",
- "metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
- "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
- "correctness": "Correctness (1-5): whether the answer is correct or not."
- },
- "CoT": {
- "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
- "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
- },
- "prompt": "You are a good assistant. Please rate the given answer to the \"classification\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- },
- "closed_qa": {
- "id": 4,
- "category": "closed_qa",
- "metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
- "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
- "correctness": "Correctness (1-5): whether the answer is correct or not."
- },
- "CoT": {
- "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
- "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
- },
- "prompt": "You are a good assistant. Please rate the given answer to the \"closed qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- },
- "extraction": {
- "id": 5,
- "category": "extraction",
- "metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
- "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
- "correctness": "correctness (1-5): Answers should extract the required information accurately and should not contain any incorrect or misleading information."
- },
- "CoT": {
- "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
- "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the questions carefully and identify the information that needs to be extracted from the material.\n2. Read the answer carefully and make sure it covers all the information that needs to be extracted.\n3. Use the material provided to verify the correctness of the response. If the response is inaccurate or contains incorrect or misleading information, a high score cannot be given.\n4. Check that the answer contains all the information required to be extracted and do not leave out any important details.\n5. Give a score between 1 and 5 based on the correctness and completeness of the response, with a score of 5 indicating a very accurate and complete response and a score of 1 indicating that the response barely extracts the required information.\n\nCorrectness:"
- },
- "prompt": "You are a good assistant. Please rate the given answer to the \"extraction\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- },
"generation": {
- "id": 6,
+ "id": 3,
"category": "generation",
"metrics": {
"language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
@@ -100,7 +55,7 @@
"prompt": "You are a good assistant. Please rate the given answer to the \"generation\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
},
"open_qa": {
- "id": 7,
+ "id": 4,
"category": "open_qa",
"metrics": {
"language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
@@ -114,23 +69,8 @@
},
"prompt": "You are a good assistant. Please rate the answers to the \"open qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
},
- "rewriting": {
- "id": 8,
- "category": "rewriting",
- "metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
- "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
- "correctness": "Correctness (1-5): whether the answer is correct or not."
- },
- "CoT": {
- "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
- "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
- },
- "prompt": "You are a good assistant. Please rate the answers to the \"rewriting\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- },
"roleplay": {
- "id": 9,
+ "id": 5,
"category": "roleplay",
"metrics": {
"language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
@@ -146,35 +86,17 @@
},
"prompt": "You are a good assistant. Please rate the given answer to the \"role-play\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
},
- "summarization": {
- "id": 10,
- "category": "summarization",
- "metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
- "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
- "correctness": "Correctness (1-5): answers should summarize the main points of the material accurately and unambiguously.",
- "conciseness": "Conciseness (1-5): answers should be concise and without redundant content."
- },
- "CoT": {
- "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
- "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the material given in the question carefully to understand its content and main points.\n2. Assess whether the answer accurately summarizes the key points of the source material.\n3. assess whether the response contains all the key information in the source material.\n4. Based on the above steps, give a score of 1-5, where 1 means that the response does not accurately summarize the main points of the material and 5 means that the response completely accurately summarizes the main points of the material.\n\nCorrectness:",
- "conciseness": "1. Read the title and extract the main points of the material.\n2. Read the summary and note the main ideas and messages in it.\n3. Assess the length of the summary. A concise summary should usually convey key information within a few sentences or paragraphs, rather than lengthy paragraphs or essays.\n4. Check that the summary does not contain information that is not relevant to the main ideas or that is redundant.\n5. Make sure that the summary covers the key information in the material and that no important details have been omitted.\n6. Rate the summary on a scale of 1-5, where 5 means the summary is concise and free of redundancy, and 1 means the summary is lengthy or contains unnecessary information that is difficult to understand or remember. Based on your judgment, assign the appropriate score.\n\nConciseness:"
- },
- "prompt": "You are a good assistant. Please rate the given answer to the \"summarization\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
- },
- "general": {
- "id": 11,
- "category": "general",
+ "Other": {
+ "id": 6,
+ "category": "Other",
"metrics": {
- "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
"relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
"correctness": "Correctness (1-5): whether the answer is correct or not."
},
"CoT": {
"language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
"relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
- "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
+ "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
},
"prompt": "You are a good assistant. Please rate the given answer to the question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
}
diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json
new file mode 100644
index 000000000000..adb540f60345
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json
@@ -0,0 +1,58 @@
+{
+ "model": [
+ {
+ "name": "model1"
+ },
+ {
+ "name": "model2"
+ }
+ ],
+ "dataset": [
+ {
+ "name": "mmlu",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "perplexity",
+ "ppl_score",
+ "ppl_score_over_choices"
+ ]
+ },
+ {
+ "name": "cmmlu",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "perplexity",
+ "ppl_score",
+ "ppl_score_over_choices"
+ ]
+ },
+ {
+ "name": "agieval",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "perplexity",
+ "ppl_score_over_choices",
+ "ppl_score"
+ ]
+ },
+ {
+ "name": "gaokaobench",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "rouge_score",
+ "rouge_zh_score",
+ "perplexity",
+ "ppl_score_over_choices",
+ "ppl_score"
+ ]
+ }
+ ]
+}
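For context, each dataset name in this config must match inference results already on disk: `eval_dataset.py` (added below) loads `{inference_results_path}/{model_name}/{dataset_name}_inference_results.json` for every model/dataset pair, so the `mmlu` entry expects, e.g., `model1/mmlu_inference_results.json` to exist.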
diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json
new file mode 100644
index 000000000000..9672c442e647
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json
@@ -0,0 +1,84 @@
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ },
+ {
+ "name": "model2 name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model2",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "agieval",
+ "dataset_class": "AGIEvalDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/agieval.json)"
+ },
+ {
+ "name": "ceval",
+ "dataset_class": "CEvalDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/ceval.json)"
+ },
+ {
+ "name": "cmmlu",
+ "dataset_class": "CMMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/cmmlu.json)"
+ },
+ {
+ "name": "gaokaobench",
+ "dataset_class": "GaoKaoBenchDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/gaokaobench.json)"
+ },
+ {
+ "name": "mmlu",
+ "dataset_class": "MMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/mmlu.json)"
+ }
+ ]
+}
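For orientation, the `parameters` block above plausibly maps onto standard Hugging Face loading calls along these lines. This is a hedged sketch under assumptions: the internals of `HuggingFaceCausalLM` are not shown in this diff, and the string `"torch.float32"` presumably gets resolved to the real dtype object.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed translation of the JSON "parameters" block into transformers calls.
model_path = "path to model"  # the "path" field; placeholder as in the config
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.model_max_length = 4096  # "model_max_length"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32,  # "torch.float32" resolved from its string form
    trust_remote_code=True,     # from "model_kwargs"
)
```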
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
new file mode 100644
index 000000000000..ec81cf0cef71
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
@@ -0,0 +1,73 @@
+import argparse
+import os
+
+import tabulate
+from colossal_eval.evaluate.dataset_evaluator import DatasetEvaluator
+from colossal_eval.utils import jdump, jload
+
+
+def main(args):
+ config = jload(args.config)
+
+ evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
+ evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
+ evaluator = DatasetEvaluator()
+
+ for dataset_parameter in config["dataset"]:
+ dataset_name = dataset_parameter["name"]
+ metrics = dataset_parameter["metrics"]
+ results_metric_model = {metric: {model["name"]: None for model in config["model"]} for metric in metrics}
+ for model in config["model"]:
+ model_name = model["name"]
+
+ data = jload(
+ os.path.join(args.inference_results_path, model_name, f"{dataset_name}_inference_results.json")
+ )
+ results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)
+
+ for metric, score in results.items():
+ results_metric_model[metric][model_name] = score["ALL"]
+
+ evaluation_results[dataset_name][model_name] = results
+
+ evaluation_results_table[dataset_name] = results_metric_model
+
+ table = []
+ header = ["dataset", "metric"] + [model["name"] for model in config["model"]]
+ table.append(header)
+
+ for dataset_parameter in config["dataset"]:
+ dataset_name = dataset_parameter["name"]
+ metrics = dataset_parameter["metrics"]
+
+ for metric, model_results in evaluation_results_table[dataset_name].items():
+ row = [dataset_name]
+ for model, score in model_results.items():
+ if len(row) == 1:
+ row.extend([metric, "{:.02f}".format(score)])
+ else:
+ row.append("{:.02f}".format(score))
+
+ table.append(row)
+
+ table = tabulate.tabulate(table, headers="firstrow")
+ print(table)
+
+ os.makedirs(args.evaluation_results_save_path, exist_ok=True)
+
+ with open(os.path.join(args.evaluation_results_save_path, "evaluation_results_table.txt"), "w") as file:
+ file.write(table)
+
+ jdump(evaluation_results, os.path.join(args.evaluation_results_save_path, "evaluation_results.json"))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval evaluation process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--inference_results_path", type=str, default=None, help="path to inference results")
+ parser.add_argument(
+ "--evaluation_results_save_path", type=str, default=None, help="path to save evaluation results"
+ )
+ args = parser.parse_args()
+
+ main(args)
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh
new file mode 100644
index 000000000000..ad0bfc03acbb
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh
@@ -0,0 +1,4 @@
+python eval_dataset.py \
+ --config "path to config file" \
+ --inference_results_path "path to inference results" \
+ --evaluation_results_save_path "path to save evaluation results"
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
new file mode 100644
index 000000000000..657fc33bf1ef
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -0,0 +1,171 @@
+import argparse
+import copy
+import os
+from typing import Dict, List
+
+import torch
+import torch.distributed as dist
+from colossal_eval import dataset, models, utils
+
+import colossalai
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
+ """
+    Remove per-rank inference results and merge them into one file.
+
+ Args:
+ world_size: Number of processes for inference.
+ save_path: The folder for storing inference results.
+ model_names: Names of models for inference.
+        dataset_names: Names of datasets for inference.
+
+ """
+
+ for model_name in model_names:
+ for dataset_name, categories in dataset_names.items():
+ all_answers = {}
+ for category in categories:
+ all_answers[category] = {"data": []}
+ answers = {"data": []}
+
+ for r in range(world_size):
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ if not os.path.exists(directory):
+ raise Exception(
+                        f"File {directory} not found. There may have been an error during inference."
+ )
+ else:
+ rank_answers = utils.jload(directory)
+ answers["data"].extend(rank_answers["data"])
+ answers["inference_kwargs"] = rank_answers["inference_kwargs"]
+
+ for r in range(world_size):
+ try:
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ os.remove(directory)
+ except Exception as e:
+ print(e)
+
+ all_answers[category] = answers
+
+ logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
+ utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
+
+ logger.info(f"Save inference results of model {model_name} for all dataset.")
+ logger.info(f"Save inference results of all models for all dataset.")
+
+
+def main(args):
+ colossalai.launch_from_torch(config={}, seed=42)
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+
+ inference_data = {}
+ debug_args = {}
+ few_shot_args = {}
+
+ config = utils.jload(args.config)
+
+ model_parameters = config["model"]
+ dataset_parameters = config["dataset"]
+
+ for dataset_parameter in dataset_parameters:
+ path = dataset_parameter["path"]
+ save_path = dataset_parameter["save_path"]
+ dataset_name = dataset_parameter["name"]
+ debug_args[dataset_name] = dataset_parameter["debug"]
+ few_shot_args[dataset_name] = dataset_parameter["few_shot"]
+
+ if not args.load_dataset:
+ if os.path.exists(save_path):
+ dataset_ = utils.jload(save_path)
+ inference_data[dataset_name] = dataset_["test"]
+ else:
+ raise Exception(
+                    "Can't find the converted dataset. You may pass --load_dataset to convert and save it first."
+ )
+
+ continue
+
+ dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}")
+ if not issubclass(dataset_class, dataset.BaseDataset):
+ raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
+
+ dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
+
+ dataset_.save(save_path)
+ inference_data[dataset_name] = dataset_.dataset["test"]
+
+ for model_parameter in model_parameters:
+ model_name = model_parameter["name"]
+ model_class = eval(f"models.{model_parameter['model_class']}")
+        if not issubclass(model_class, models.BaseModel):
+            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
+
+        parameters = model_parameter["parameters"]
+        parameters.update({"logger": logger})
+        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})
+
+        model_ = model_class(**parameters)
+
+ for dataset_name, split_data in inference_data.items():
+ start = 0
+ for category, category_data in split_data.items():
+ if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
+ raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
+
+ answers_to_dump = copy.deepcopy(category_data)
+ partition_size = len(category_data["data"]) // world_size
+ redundant = len(category_data["data"]) % world_size
+
+ # Ensure that the amount of data for inference is as consistent as possible across different processes.
+ lengths = [partition_size for _ in range(world_size)]
+ for j in range(redundant):
+ lengths[(j + start) % world_size] += 1
+
+ start = (start + redundant) % world_size
+
+ questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+
+ answers_per_rank = model_.inference(
+ questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ )
+
+ answers_to_dump["data"] = answers_per_rank
+
+ utils.jdump(
+ answers_to_dump,
+ os.path.join(
+ args.inference_save_path,
+ model_name,
+ f"{dataset_name}_{category}_inference_results_rank{rank}.json",
+ ),
+ )
+
+ logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+
+ del model_
+ torch.cuda.empty_cache()
+
+ dist.barrier()
+ if rank == 0:
+ model_names = [model_parameter["name"] for model_parameter in model_parameters]
+ dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
+ rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval inference process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--load_dataset", default=False, action="store_true")
+ parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
+ args = parser.parse_args()
+
+ main(args)
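The remainder-distribution logic in `main` above (the `lengths` computation) is easiest to verify in isolation. A minimal standalone sketch with hypothetical sizes, 4 ranks and 10 samples; the `start` offset carried across categories rotates which ranks absorb the extra samples:

```python
world_size = 4   # hypothetical number of ranks
data_len = 10    # hypothetical size of one category
start = 0        # carried over between categories, as in main()

partition_size = data_len // world_size  # 2
redundant = data_len % world_size        # 2 leftover samples

lengths = [partition_size] * world_size
for j in range(redundant):
    lengths[(j + start) % world_size] += 1

# Ranks 0 and 1 each take one extra sample; every sample is covered exactly once.
assert lengths == [3, 3, 2, 2]
assert sum(lengths) == data_len
# Rank r then slices data[sum(lengths[:r]) : sum(lengths[:r]) + lengths[r]],
# and rm_and_merge() later concatenates the per-rank result files in rank order.
```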
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.sh b/applications/ColossalEval/examples/dataset_evaluation/inference.sh
new file mode 100644
index 000000000000..15f9afd56045
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.sh
@@ -0,0 +1,4 @@
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json
new file mode 100644
index 000000000000..6ebe3996b1cf
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json
@@ -0,0 +1,44 @@
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
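Note that the categories listed here line up with the prompt file trimmed earlier in this diff: classification, closed_qa, extraction, rewriting, and summarization were removed and general was renamed to Other, leaving (presumably) brainstorming, chat, generation, open_qa, roleplay, and Other as the valid categories for GPT evaluation.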
diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json
new file mode 100644
index 000000000000..7ed7491a87c5
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json
@@ -0,0 +1,33 @@
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "colossal",
+ "dataset_class": "ColossalDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "../../configs/gpt_evaluation/data/eval_en_examples.json",
+            "save_path": "path to save converted dataset (e.g. inference_data/colossal.json)"
+ }
+ ]
+}
diff --git a/applications/ColossalEval/examples/gpt_evaluation/eval.py b/applications/ColossalEval/examples/gpt_evaluation/eval.py
new file mode 100644
index 000000000000..cd521af59823
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/eval.py
@@ -0,0 +1,139 @@
+import argparse
+import os
+
+import openai
+from colossal_eval.evaluate.evaluator import Evaluator
+from colossal_eval.utils import jload
+
+
+def main(args):
+ assert len(args.answer_file_list) == len(
+ args.model_name_list
+ ), "The number of answer files and model names should be equal!"
+
+ # load config
+ config = jload(args.config_file)
+
+ if config["language"] in ["cn", "en"]:
+ # get metric settings for all categories
+ metrics_per_category = {}
+ for category in config["category"].keys():
+ metrics_all = {}
+ for metric_type, metrics in config["category"][category].items():
+ metrics_all[metric_type] = metrics
+ metrics_per_category[category] = metrics_all
+
+ battle_prompt = None
+ if args.battle_prompt_file:
+ battle_prompt = jload(args.battle_prompt_file)
+
+ gpt_evaluation_prompt = None
+ if args.gpt_evaluation_prompt_file:
+ gpt_evaluation_prompt = jload(args.gpt_evaluation_prompt_file)
+
+ if len(args.model_name_list) == 2 and not battle_prompt:
+ raise Exception("No prompt file for battle provided. Please specify the prompt file for battle!")
+
+ if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
+ raise Exception(
+ "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
+ )
+
+ if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
+ raise Exception(
+ "GPT evaluation with reference is not supported for text-davinci-003. You should specify chat models such as gpt-3.5-turbo or gpt-4."
+ )
+
+ # initialize evaluator
+ evaluator = Evaluator(
+ metrics_per_category,
+ battle_prompt,
+ gpt_evaluation_prompt,
+ args.gpt_model,
+ config["language"],
+ args.gpt_with_reference,
+ )
+ if len(args.model_name_list) == 2:
+ answers_1 = jload(args.answer_file_list[0])
+ answers_2 = jload(args.answer_file_list[1])
+
+ answers1 = []
+ for category, value in answers_1.items():
+ answers1.extend(value["data"])
+
+ answers2 = []
+ for category, value in answers_2.items():
+ answers2.extend(value["data"])
+
+ assert len(answers1) == len(answers2), "The number of answers for two models should be equal!"
+
+ evaluator.battle(answers1=answers1, answers2=answers2)
+ evaluator.save(args.save_path, args.model_name_list)
+ elif len(args.model_name_list) == 1:
+ targets = jload(args.target_file)
+ answers = jload(args.answer_file_list[0])
+
+ references = []
+ for category, value in targets["test"].items():
+ references.extend(value["data"])
+
+ predictions = []
+ for category, value in answers.items():
+ predictions.extend(value["data"])
+
+ assert len(references) == len(
+ predictions
+ ), "The number of target answers and model answers should be equal!"
+
+ evaluator.evaluate(
+ answers=predictions, targets=references, save_path=args.save_path, model_name=args.model_name_list[0]
+ )
+ evaluator.save(args.save_path, args.model_name_list)
+ else:
+ raise ValueError("Unsupported number of answer files and model names!")
+ else:
+ raise ValueError(f'Unsupported language {config["language"]}!')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
+ parser.add_argument(
+ "--config_file", type=str, default=None, required=True, help="path to the file of target results"
+ )
+ parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
+ parser.add_argument(
+ "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
+ )
+ parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
+ parser.add_argument(
+ "--answer_file_list",
+ type=str,
+ nargs="+",
+ default=[],
+ required=True,
+ help="path to the answer files of at most 2 models",
+ )
+ parser.add_argument(
+ "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
+ )
+ parser.add_argument(
+ "--gpt_model",
+ default="gpt-3.5-turbo-16k",
+ choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"],
+ help="which GPT model to use for evaluation",
+ )
+ parser.add_argument(
+ "--gpt_with_reference",
+ default=False,
+ action="store_true",
+ help="whether to include reference answer in gpt evaluation",
+ )
+ parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
+ parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
+ args = parser.parse_args()
+
+ if args.openai_key is not None:
+ os.environ["OPENAI_API_KEY"] = args.openai_key
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+ main(args)
diff --git a/applications/Chat/evaluate/eval.sh b/applications/ColossalEval/examples/gpt_evaluation/eval.sh
old mode 100755
new mode 100644
similarity index 100%
rename from applications/Chat/evaluate/eval.sh
rename to applications/ColossalEval/examples/gpt_evaluation/eval.sh
diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.py b/applications/ColossalEval/examples/gpt_evaluation/inference.py
new file mode 100644
index 000000000000..657fc33bf1ef
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/inference.py
@@ -0,0 +1,171 @@
+import argparse
+import copy
+import os
+from typing import Dict, List
+
+import torch
+import torch.distributed as dist
+from colossal_eval import dataset, models, utils
+
+import colossalai
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
+ """
+    Remove per-rank inference results and merge them into one file.
+
+ Args:
+ world_size: Number of processes for inference.
+ save_path: The folder for storing inference results.
+ model_names: Names of models for inference.
+        dataset_names: Names of datasets for inference.
+
+ """
+
+ for model_name in model_names:
+ for dataset_name, categories in dataset_names.items():
+ all_answers = {}
+ for category in categories:
+ all_answers[category] = {"data": []}
+ answers = {"data": []}
+
+ for r in range(world_size):
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ if not os.path.exists(directory):
+ raise Exception(
+                        f"File {directory} not found. There may have been an error during inference."
+ )
+ else:
+ rank_answers = utils.jload(directory)
+ answers["data"].extend(rank_answers["data"])
+ answers["inference_kwargs"] = rank_answers["inference_kwargs"]
+
+ for r in range(world_size):
+ try:
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ os.remove(directory)
+ except Exception as e:
+ print(e)
+
+ all_answers[category] = answers
+
+ logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
+ utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
+
+ logger.info(f"Save inference results of model {model_name} for all dataset.")
+ logger.info(f"Save inference results of all models for all dataset.")
+
+
+def main(args):
+ colossalai.launch_from_torch(config={}, seed=42)
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+
+ inference_data = {}
+ debug_args = {}
+ few_shot_args = {}
+
+ config = utils.jload(args.config)
+
+ model_parameters = config["model"]
+ dataset_parameters = config["dataset"]
+
+ for dataset_parameter in dataset_parameters:
+ path = dataset_parameter["path"]
+ save_path = dataset_parameter["save_path"]
+ dataset_name = dataset_parameter["name"]
+ debug_args[dataset_name] = dataset_parameter["debug"]
+ few_shot_args[dataset_name] = dataset_parameter["few_shot"]
+
+ if not args.load_dataset:
+ if os.path.exists(save_path):
+ dataset_ = utils.jload(save_path)
+ inference_data[dataset_name] = dataset_["test"]
+ else:
+ raise Exception(
+                    "Can't find the converted dataset. You may pass --load_dataset to convert and save it first."
+ )
+
+ continue
+
+ dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}")
+ if not issubclass(dataset_class, dataset.BaseDataset):
+ raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
+
+ dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
+
+ dataset_.save(save_path)
+ inference_data[dataset_name] = dataset_.dataset["test"]
+
+ for model_parameter in model_parameters:
+ model_name = model_parameter["name"]
+ model_class = eval(f"models.{model_parameter['model_class']}")
+        if not issubclass(model_class, models.BaseModel):
+            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
+
+        parameters = model_parameter["parameters"]
+        parameters.update({"logger": logger})
+        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})
+
+        model_ = model_class(**parameters)
+
+ for dataset_name, split_data in inference_data.items():
+ start = 0
+ for category, category_data in split_data.items():
+ if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
+ raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
+
+ answers_to_dump = copy.deepcopy(category_data)
+ partition_size = len(category_data["data"]) // world_size
+ redundant = len(category_data["data"]) % world_size
+
+ # Ensure that the amount of data for inference is as consistent as possible across different processes.
+ lengths = [partition_size for _ in range(world_size)]
+ for j in range(redundant):
+ lengths[(j + start) % world_size] += 1
+
+ start = (start + redundant) % world_size
+
+ questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+
+ answers_per_rank = model_.inference(
+ questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ )
+
+ answers_to_dump["data"] = answers_per_rank
+
+ utils.jdump(
+ answers_to_dump,
+ os.path.join(
+ args.inference_save_path,
+ model_name,
+ f"{dataset_name}_{category}_inference_results_rank{rank}.json",
+ ),
+ )
+
+ logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+
+ del model_
+ torch.cuda.empty_cache()
+
+ dist.barrier()
+ if rank == 0:
+ model_names = [model_parameter["name"] for model_parameter in model_parameters]
+ dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
+ rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval inference process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--load_dataset", default=False, action="store_true")
+ parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.sh b/applications/ColossalEval/examples/gpt_evaluation/inference.sh
new file mode 100644
index 000000000000..15f9afd56045
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/inference.sh
@@ -0,0 +1,4 @@
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt
new file mode 100644
index 000000000000..c110606e0303
--- /dev/null
+++ b/applications/ColossalEval/requirements.txt
@@ -0,0 +1,12 @@
+transformers>=4.32.0
+colossalai>=0.3.1
+peft
+tabulate
+jieba
+fuzzywuzzy
+rouge
+openai
+matplotlib
+pandas
+seaborn
+scikit-learn
diff --git a/applications/ColossalEval/setup.py b/applications/ColossalEval/setup.py
new file mode 100644
index 000000000000..4f7b1bb5c42e
--- /dev/null
+++ b/applications/ColossalEval/setup.py
@@ -0,0 +1,31 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+setup(
+ name="colossal_eval",
+ version="0.0.1",
+ packages=find_packages(exclude=["examples", "*.egg-info"]),
+ description="Colossal-AI LLM-Evaluation Framework",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/LLM-Evaluation",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+)
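With this setup.py, installing the package (e.g. running `pip install .` inside `applications/ColossalEval`) is presumably what makes the `from colossal_eval import dataset, models, utils` imports in the example scripts above resolvable.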
diff --git a/applications/README.md b/applications/README.md
index cd0435aae199..f5078e06a73b 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -4,8 +4,10 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
-- [X] [Chatbot](./Chat/README.md)
-- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters
+- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
+- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
+- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
+- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.
diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index f859161f7810..7da55590305b 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -1,11 +1,4 @@
-from .initialize import (
- get_default_parser,
- initialize,
- launch,
- launch_from_openmpi,
- launch_from_slurm,
- launch_from_torch,
-)
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
try:
# .version will be created by setup.py
@@ -13,5 +6,7 @@
except ModuleNotFoundError:
# this will only happen if the user did not run `pip install`
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice
- __version__ = '0.0.0'
- print('please install Colossal-AI from https://www.colossalai.org/download or from source')
+ __version__ = "0.0.0"
+ print("please install Colossal-AI from https://www.colossalai.org/download or from source")
+
+__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"]
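With `initialize` and `get_default_parser` no longer re-exported, callers go through the launch helpers, as the new example scripts in this diff already do. A minimal sketch:

```python
import colossalai

# Matches the usage in applications/ColossalEval/examples/*/inference.py above;
# under torchrun, this initializes the default distributed process group.
colossalai.launch_from_torch(config={}, seed=42)
```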
diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
index 4049be79c70f..e8ba88b0406d 100644
--- a/colossalai/_analyzer/_subclasses/_meta_registration.py
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from packaging import version
@@ -24,25 +24,23 @@
def new(*args, **kwargs):
- return orig_empty(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs):
- return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs):
- return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -54,7 +52,7 @@ def add_func(op):
return wrapper
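For orientation, a hypothetical use of the `register_meta` decorator reformatted above (`meta_relu` is illustrative only and not part of this diff):

```python
# Hypothetical registration: a meta kernel computes only output metadata
# (shape/dtype on the meta device), never real values.
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor) -> torch.Tensor:
    return new_like(input)
```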
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
@@ -69,7 +67,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
@@ -146,7 +143,8 @@ def calc_conv_nd_return_shape(
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
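`_formula` above presumably implements the standard convolution output-length arithmetic, floor((ln + 2p - d(k - 1) - 1) / s) + 1. A quick standalone sanity check with hypothetical values:

```python
def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
    # Standard conv output length: floor((ln + 2p - d*(k - 1) - 1) / s) + 1
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

# A 224-wide input, 3x3 kernel, stride 2, padding 1, dilation 1 -> 112.
assert conv_out_len(224, p=1, d=1, k=3, s=2) == 112
```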
@@ -180,19 +178,39 @@ def pick_memory_format():
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
- def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+ def meta__conv(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+ ):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
- def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
+ def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+ ):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
- def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
- return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
- ()) # (grad_input, grad_weight, grad_hx, grad_cx)
+ def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ):
+ return (
+ new_like(input),
+ new_like(weight),
+ new_like(hx),
+ new_like(cx) if cx is not None else new(()),
+ ) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@@ -278,7 +304,7 @@ def meta_cudnn_rnn_backward(input: torch.Tensor,
aten.hardtanh_backward.default,
]
- if version.parse(torch.__version__) < version.parse('2.0.0'):
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [
aten.prelu_backward.default,
]
@@ -296,37 +322,61 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
- def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, train, eps, output_mask):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
- return new_like(input), new((n_input)), new((n_input)), new(
- (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
+ return (
+ new_like(input),
+ new((n_input)),
+ new((n_input)),
+ new((0), dtype=torch.uint8),
+ ) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
- def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
- return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
+        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))  # (output, mean, rstd)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
- def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
- return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
+ def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+ ):
+ return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
@@ -355,8 +405,9 @@ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Te
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
- def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
+ def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+ ):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
- return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
+ return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
- return new_like(grad) # (grad_in)
+ return new_like(grad) # (grad_in)
- if version.parse(torch.__version__) < version.parse('1.13.0'):
+ if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
- assert len(
- indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+ assert (
+ len(indices) <= self.ndim
+ ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs
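
The `meta_conv` hunks above all revolve around one closed-form rule: each spatial dimension of a convolution output follows `floor((L_in + 2p - d(k-1) - 1)/s) + 1`, which is what the file's `_formula` helper computes. Below is a minimal sketch of that formula checked against an eager `nn.Conv2d`; the helper name `conv_out_len` is ours, not the library's.

```python
import torch


def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
    # L_out = floor((L_in + 2*p - d*(k - 1) - 1) / s) + 1
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1


x = torch.randn(1, 3, 32, 32)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
# both spatial dims shrink 32 -> 16 under stride 2, padding 1, kernel 3
assert conv(x).shape[2] == conv_out_len(32, p=1, d=1, k=3, s=2) == 16
```
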
diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py
index b3ec98f0811f..503981409cca 100644
--- a/colossalai/_analyzer/_subclasses/_monkey_patch.py
+++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -1,5 +1,4 @@
import torch
-import torch.distributed as dist
from packaging import version
__all__ = [
@@ -48,7 +47,7 @@
"scatter",
]
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py
index 59991dc50912..9d52c5593bb8 100644
--- a/colossalai/_analyzer/_subclasses/flop_tensor.py
+++ b/colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -8,7 +8,7 @@
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Union
import torch
from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9
T = 1e12
if flop < K:
- return f'{flop:.2f}'
+ return f"{flop:.2f}"
elif flop < M:
- return f'{flop / K:.2f}K'
+ return f"{flop / K:.2f}K"
elif flop < B:
- return f'{flop / M:.2f}M'
+ return f"{flop / M:.2f}M"
elif flop < T:
- return f'{flop / B:.2f}B'
+ return f"{flop / B:.2f}B"
else:
- return f'{flop / T:.2f}T'
+ return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
- maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
- or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+ maybe_inplace = (
+ getattr(module, "inplace", False)
+ or kwargs.get("inplace", False)
+ or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
+ )
class DummyModule(torch.nn.Module):
-
def __init__(self, func):
super().__init__()
self.func = func
@@ -74,21 +76,20 @@ def forward(self, *args, **kwargs):
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
- parents = ['Global']
+ parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
- name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+ name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
@@ -115,9 +116,7 @@ def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
-
class PushState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def backward(ctx, *grad_outs):
return PushState.apply
def create_backwards_pop(name):
-
class PopState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def forward(ctx, *args):
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
-
def f(module, inputs):
nonlocal parents
parents.append(name)
@@ -165,10 +161,9 @@ def f(module, inputs):
return f
def exit_module(name):
-
def f(module, inputs, outputs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def display_flops():
for mod in flop_counts.keys():
print("Module: ", mod)
for k, v in flop_counts[mod].items():
- print('\t', k, _format_flops(v))
+ print("\t", k, _format_flops(v))
print()
def detach_variables(r):
@@ -201,7 +196,7 @@ def detach_variables(r):
def wrap(r):
if isinstance(r, torch.Tensor):
- data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+ data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -420,33 +418,30 @@ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+ Count flops for zero flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = {
- # gemm
+ # gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ def zero_flop_jit(*args):
}
ewise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -485,8 +480,7 @@ def zero_flop_jit(*args):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ def zero_flop_jit(*args):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
-
- # distribution
+ # distribution
aten.bernoulli_.float,
-
- # where
+ # where
aten.where.self,
]
for op in ewise_flop_aten:
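
`flop_tensor.py` intercepts aten ops via `__torch_dispatch__` on a tensor subclass and looks each op up in `flop_mapping`. The toy below reproduces the idea with a `TorchDispatchMode` instead of a subclass, handling only `aten.mm`; counting `m*k*n` multiply-accumulates per matmul is the usual fvcore-style convention, assumed here rather than quoted from this file. It needs a PyTorch recent enough to expose `TorchDispatchMode` (the module itself gates on `>= 1.12.0`).

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten


class ToyFlopMode(TorchDispatchMode):
    """Count multiply-accumulates for aten.mm only; everything else passes through."""

    def __init__(self):
        super().__init__()
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is aten.mm.default:
            m, k = args[0].shape
            n = args[1].shape[1]
            self.flops += m * k * n
        return func(*args, **(kwargs or {}))


a, b = torch.randn(4, 8), torch.randn(8, 2)
with ToyFlopMode() as mode:
    torch.mm(a, b)
print(mode.flops)  # 4 * 8 * 2 = 64
```
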
diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py
index 2bc212938ee0..8be97d01343e 100644
--- a/colossalai/_analyzer/_subclasses/meta_tensor.py
+++ b/colossalai/_analyzer/_subclasses/meta_tensor.py
@@ -3,12 +3,12 @@
import torch
import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack of inplace execution in PyTorch
def _assert_alias(func):
- return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
- )
+ return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ def __new__(cls, elem, device=None, data_ptr_fn=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=requires_grad) # deceive the frontend for aten selections
+ device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensors not already on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ def __new__(cls, elem, device=None, data_ptr_fn=None):
return r
def __repr__(self):
- name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+ name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ def unwrap(x):
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ def replace(x):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
- return torch.device('meta')
+ return torch.device("meta")
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
"""
def __init__(self):
- self.torch_overrides = {} # override torch.xxx
- self.dist_overrides = {} # override torch.distributed.xxx
+ self.torch_overrides = {} # override torch.xxx
+ self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
-
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
- return MetaTensor(orig_new(*args, **{
- **kwargs, 'device': 'meta'
- }),
- device=kwargs.get('device', torch.device('cpu')))
+ return MetaTensor(
+ orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+ )
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)
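
The `MetaTensor` changes above all serve one mechanism: keep the wrapped data on PyTorch's `meta` device so every aten op propagates shape and dtype without allocating storage, while the wrapper "deceives the frontend" about the apparent device. Stripped of the subclass, the core trick is just the public meta device:

```python
import torch

w = torch.empty(4096, 4096, device="meta")
x = torch.empty(8, 4096, device="meta")
y = x @ w.t()             # shape/dtype propagate; no real memory is touched
print(y.shape, y.device)  # torch.Size([8, 4096]) meta
```
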
diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py
index 41d74f2e3719..cd244b22cac0 100644
--- a/colossalai/_analyzer/fx/codegen.py
+++ b/colossalai/_analyzer/fx/codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple
import torch
@@ -22,7 +22,7 @@
import colossalai
from colossalai.fx._compatibility import compatibility
-_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+_register_custom_builtin("colossalai", "import colossalai", colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- return node.meta['info'].activation_checkpoint[ckpt_level] is not None
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ return node.meta["info"].activation_checkpoint[ckpt_level] is not None
return True
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- ckpt_level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
- if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
+ if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
- ckpt_level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ ckpt_level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
- usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+ usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -251,7 +253,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -259,7 +261,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -281,16 +283,16 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -309,19 +311,18 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -347,82 +348,94 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
@@ -432,13 +445,13 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -447,11 +460,11 @@ def emit_node(node: Node, body):
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
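
At runtime, the emitted `torch.utils.checkpoint.checkpoint(self.checkpoint_<label>, ..., use_reentrant=False)` line trades activation memory for a second forward pass during backward. A self-contained illustration with an arbitrary stand-in block (this is hand-written, not codegen output):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(32, 64, requires_grad=True)

# Equivalent in spirit to the emitted line:
#   out = torch.utils.checkpoint.checkpoint(self.checkpoint_<label>, x, use_reentrant=False)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()       # block's forward is recomputed here
print(x.grad.shape)        # torch.Size([32, 64])
```
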
diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py
index 1fdedd758c01..9d3999e322b9 100644
--- a/colossalai/_analyzer/fx/graph_module.py
+++ b/colossalai/_analyzer/fx/graph_module.py
@@ -13,6 +13,7 @@
try:
from torch.fx.graph import _PyTreeCodeGen
+
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
-
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
@@ -50,12 +50,14 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
# constituent substrings of the error message
tb_repr = traceback.format_exc()
- custom_msg = ("Call using an FX-traced Module, "
- f"line {err_lineno} of the traced Module's "
- "generated forward function:")
- before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+ custom_msg = (
+ "Call using an FX-traced Module, "
+ f"line {err_lineno} of the traced Module's "
+ "generated forward function:"
+ )
+ before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
- err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+ err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ def __call__(self, obj, *args, **kwargs):
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
- return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
+ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
- topmost_framesummary: traceback.FrameSummary = \
- traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
+ topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
+ traceback.walk_tb(e.__traceback__)
+ )[
+ -1
+ ] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code.
"""
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: torch.fx.Graph,
- class_name: str = 'GraphModule'):
+ def __init__(
+ self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
+ ):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ def recompile(self) -> PythonCode:
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ def recompile(self) -> PythonCode:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -208,10 +212,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -228,12 +232,14 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
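
`ColoGraphModule` works because an FX graph compiles to ordinary Python source: `recompile()` re-emits it, which is why `_WrappedCall` can point tracebacks at concrete lines of the generated `forward`. The stock `torch.fx` flow shows the same machinery:

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(M())
print(gm.code)   # the generated forward() as plain Python source
gm.recompile()   # re-emit and re-exec the source after any graph edits
print(gm(torch.randn(2)))
```
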
diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py
index fbe8400a437e..d2671787ea63 100644
--- a/colossalai/_analyzer/fx/node_util.py
+++ b/colossalai/_analyzer/fx/node_util.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node
from colossalai._analyzer.envs import MeshConfig
@@ -85,12 +85,12 @@ class MetaInfo:
node: Node
# directory
- mod_dir: str = ''
+ mod_dir: str = ""
# ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward
- global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
- curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
+ global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
+ curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation
# ============================== Update ====================================
@@ -100,7 +100,7 @@ class MetaInfo:
inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = ()
- is_alias: Tuple[bool] = () # whether the output is an alias of input
+ is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost
fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
# should keep the same whenever manipulated
# ============================= Invariant ==================================
- activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
+ activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False
- sharding_spec: str = 'RR'
+ sharding_spec: str = "RR"
def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__
# if initialized, return the existing one
# should disable the __init__ function
- if node.meta.get('info', None) is not None:
+ if node.meta.get("info", None) is not None:
def _dummy(self, *args, **kwargs):
- if getattr(self, '_is_init', False):
+ if getattr(self, "_is_init", False):
self._is_init = True
orig_init(self, *args, **kwargs)
cls.__init__ = orig_init
cls.__init__ = _dummy
- return node.meta['info']
+ return node.meta["info"]
return super().__new__(cls)
def __post_init__(self):
- self.node.meta['info'] = self
+ self.node.meta["info"] = self
@property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ def backward_size(self):
return compute_size_in_bytes(self.inputs)
def __repr__(self):
- s = f'Node {self.node.name}'
+ s = f"Node {self.node.name}"
if self.parameters:
- s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+ s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers:
- s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+ s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size:
- s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+ s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
- s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+ s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size:
- s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
- s += f'\n\tfwd_flop = {self.fwd_flop}'\
- f'\n\tbwd_flop = {self.bwd_flop}'\
- f'\n\tfwd_comm = {self.fwd_comm}'\
- f'\n\tbwd_comm = {self.bwd_comm}'\
- f'\n\tto_recompute = {self.to_recompute}'\
- f'\n\tto_offload = {self.to_offload}'\
- f'\n\tsharding_spec = {self.sharding_spec}'
+ s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+ s += (
+ f"\n\tfwd_flop = {self.fwd_flop}"
+ f"\n\tbwd_flop = {self.bwd_flop}"
+ f"\n\tfwd_comm = {self.fwd_comm}"
+ f"\n\tbwd_comm = {self.bwd_comm}"
+ f"\n\tto_recompute = {self.to_recompute}"
+ f"\n\tto_offload = {self.to_offload}"
+ f"\n\tsharding_spec = {self.sharding_spec}"
+ )
return s
diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py
index c3e760b31e96..158ebce219cd 100644
--- a/colossalai/_analyzer/fx/passes/graph_profile.py
+++ b/colossalai/_analyzer/fx/passes/graph_profile.py
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.fx
-from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
@@ -13,14 +13,14 @@
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
- return f'{flops / 1e12:.2f} TFLOPs'
+ return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
- return f'{flops / 1e9:.2f} GFLOPs'
+ return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
- return f'{flops / 1e6:.2f} MFLOPs'
+ return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
- return f'{flops / 1e3:.2f} kFLOPs'
- return f'{flops} FLOPs'
+ return f"{flops / 1e3:.2f} kFLOPs"
+ return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
+
_profileable = [
- 'call_function',
- 'call_module',
- 'call_method',
+ "call_function",
+ "call_module",
+ "call_method",
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_pr
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
-
- self.run_node(node) # No need to store.
+ self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
- if node.op == 'output':
+ if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
@@ -133,9 +133,11 @@ def summary(self) -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ def summary(self) -> str:
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
- node_summaries.append([
- node.op,
- str(node),
- _format_memory(n_info.accumulate_size),
- _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
- _format_memory(n_info.output_size),
- _format_memory(n_info.temp_size),
- _format_memory(n_info.param_size),
- _format_memory(n_info.backward_size),
- _format_flops(n_info.fwd_flop),
- _format_flops(n_info.bwd_flop),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ _format_memory(n_info.accumulate_size),
+ _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
+ _format_memory(n_info.output_size),
+ _format_memory(n_info.temp_size),
+ _format_memory(n_info.param_size),
+ _format_memory(n_info.backward_size),
+ _format_flops(n_info.fwd_flop),
+ _format_flops(n_info.bwd_flop),
+ ]
+ )
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Accumulate size',
- 'Incremental size',
- 'Output size',
- 'Temp size',
- 'Param size',
- 'Backward size',
- 'Fwd FLOPs',
- 'Bwd FLOPs',
+ "Op type",
+ "Op",
+ "Accumulate size",
+ "Incremental size",
+ "Output size",
+ "Temp size",
+ "Param size",
+ "Backward size",
+ "Fwd FLOPs",
+ "Bwd FLOPs",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class with the ``@register_flop_count_impl`` decorator:
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
+
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ def run_node(self, n: torch.fx.Node) -> Any:
(
n_info.fwd_flop,
n_info.bwd_flop,
- ) = getattr(self, n.op)(n.target, args, kwargs)
+ ) = getattr(
+ self, n.op
+ )(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
- f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
- f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+ f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
+ f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e
# retain the autograd graph
@@ -259,7 +266,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
return _denormalize_tuple(n_info.outputs)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
else:
return flop_count(target, *args, **kwargs)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
@@ -301,7 +308,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
- for profiler_cls in (FlopProfiler,
- # CommunicationProfiler, # TODO: add communication profiling
- ):
+ for profiler_cls in (
+ FlopProfiler,
+ # CommunicationProfiler, # TODO: add communication profiling
+ ):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))
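
`GraphProfiler`'s skeleton is `torch.fx.Interpreter`: override `run_node`, record something per node, then delegate execution to the parent class. The toy below only tallies visits per opcode; the real profiler fills `MetaInfo` fields instead.

```python
from collections import Counter

import torch
import torch.fx


class VisitCounter(torch.fx.Interpreter):
    def __init__(self, module):
        super().__init__(module)
        self.visits = Counter()

    def run_node(self, n: torch.fx.Node):
        self.visits[n.op] += 1         # record, then execute as usual
        return super().run_node(n)


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.t())


gm = torch.fx.symbolic_trace(M())
counter = VisitCounter(gm)
counter.run(torch.randn(4, 4))
print(counter.visits)  # e.g. Counter({'call_function': 2, 'placeholder': 1, ...})
```
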
diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py
index 23e83013e02f..8d44f1d4b59d 100644
--- a/colossalai/_analyzer/fx/passes/shape_prop.py
+++ b/colossalai/_analyzer/fx/passes/shape_prop.py
@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
- return torch.device('cpu')
+ return torch.device("cpu")
@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
+
_custom_dispatch_func = {}
_mode = MetaTensorMode()
@@ -115,15 +116,14 @@ def run_node(self, n: torch.fx.Node) -> Any:
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
-
def _convert_meta(t: torch.Tensor):
- if t.device == 'meta':
+ if t.device == "meta":
return t
else:
- return t.to('meta')
+ return t.to("meta")
if isinstance(elem, MetaTensor):
- if getattr(self, '_is_param', False):
+ if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
@@ -139,21 +139,24 @@ def _convert_meta(t: torch.Tensor):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
- if n.op == 'call_module':
+ if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
- n_info.parameters.update({
- k.name: MetaTensor(v)
- for k, v in zip(n.args, args)
- if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
- })
+ n_info.parameters.update(
+ {
+ k.name: MetaTensor(v)
+ for k, v in zip(n.args, args)
+ if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+ }
+ )
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
- n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
- tuple(v for v in kwargs.values() if is_pure_tensor(v))
+ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
+ v for v in kwargs.values() if is_pure_tensor(v)
+ )
# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ def _convert_meta(t: torch.Tensor):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
- def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[st
else:
return res
- def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
@@ -218,7 +221,8 @@ def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str,
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
- args[0], torch.nn.parameter.Parameter):
+ args[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
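
Colossal-AI's `ShapeProp` records meta outputs into each node's `MetaInfo`; the same propagate-an-example-input flow exists in stock `torch.fx`, which is enough to watch shape annotations land in `node.meta`:

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))       # run once with an example input
for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")            # TensorMetadata(shape, dtype, ...)
    print(node.name, getattr(tm, "shape", None))
```
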
diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py
index dd7f22c6c98a..5732a6665f78 100644
--- a/colossalai/_analyzer/fx/symbolic_profile.py
+++ b/colossalai/_analyzer/fx/symbolic_profile.py
@@ -1,5 +1,3 @@
-import torch
-import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@
def register_flop_count_impl(func):
-
def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
@@ -16,7 +13,6 @@ def wrapper(impl):
def register_shape_impl(func):
-
def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
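
Both `register_flop_count_impl` and `register_shape_impl` are the same two-level decorator: call it with the op to override, and it files the decorated handler into a class-level dict keyed by that op. A generic, library-free sketch; returning `(0, 0)` mirrors the `(fwd_flop, bwd_flop)` convention from the docstring example above.

```python
from typing import Callable, Dict

_custom_impl: Dict[Callable, Callable] = {}


def register_impl(func: Callable):
    def wrapper(impl: Callable):
        _custom_impl[func] = impl   # key the handler by the op it covers
        return impl

    return wrapper


@register_impl(sum)                 # stand-in for e.g. an aten op or F.linear
def sum_impl(*args, **kwargs):
    return 0, 0                     # (fwd_flop, bwd_flop)


assert _custom_impl[sum] is sum_impl
```
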
diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py
index 1e75b47ca5b0..b8b83282b42c 100644
--- a/colossalai/_analyzer/fx/tracer/bias_addition.py
+++ b/colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -12,7 +12,7 @@
__all__ = []
-@register_tracer_impl(F.linear, name='_bias_addition_impl')
+@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias
-@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1))
+ (-1, 1)
+ )
-@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1))
+ (-1, 1, 1)
+ )
-@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
- weight,
- bias=None,
- stride=_single(1),
- padding=_single(0),
- output_padding=_single(0),
- groups=1,
- dilation=_single(1)):
+ (-1, 1, 1, 1)
+ )
+
+
+@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
+def conv_transpose1d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_single(1),
+ padding=_single(0),
+ output_padding=_single(0),
+ groups=1,
+ dilation=_single(1),
+):
if bias is None:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1))
-
-
-@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
- weight,
- bias=None,
- stride=_pair(1),
- padding=_pair(0),
- output_padding=_pair(0),
- groups=1,
- dilation=_pair(1)):
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1))
+
+
+@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
+def conv_transpose2d_impl(
+ input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
+):
if bias is None:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
- weight,
- bias=None,
- stride=_triple(1),
- padding=_triple(0),
- output_padding=_triple(0),
- groups=1,
- dilation=_triple(1)):
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1))
+
+
+@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
+def conv_transpose3d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_triple(1),
+ padding=_triple(0),
+ output_padding=_triple(0),
+ groups=1,
+ dilation=_triple(1),
+):
if bias is None:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1, 1))
-
-
-@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1, 1))
+
+
+@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input
-@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
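The same registration mechanism can decompose other bias-carrying ops; a sketch using `F.bilinear` (this particular decomposition is an illustrative assumption, only linear/conv/addmm/addbmm are registered above):

```python
import torch.nn.functional as F

from colossalai._analyzer.fx.tracer.tracer import register_tracer_impl


@register_tracer_impl(F.bilinear, name="_bias_addition_impl")
def bilinear_impl(input1, input2, weight, bias=None):
    if bias is None:
        return F.bilinear(input1, input2, weight)
    # emit the bias-free op plus an explicit add, so they trace as two nodes
    return F.bilinear(input1, input2, weight) + bias
```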
diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
index 112c7c9637d2..ff6b55be5117 100644
--- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
+++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
@@ -4,6 +4,7 @@
try:
import apex
+
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
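A sketch of registering a user module as a tracer leaf, mirroring the apex registrations above; the module is hypothetical and the import path is assumed from this file's own usage of `register_leaf_module`:

```python
import torch.nn as nn

from colossalai._analyzer.fx.tracer.custom_leaf_module import register_leaf_module


class MyFusedNorm(nn.Module):  # hypothetical module, illustration only
    def forward(self, x):
        return x


register_leaf_module(MyFusedNorm)  # the tracer will not trace into its forward
```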
diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py
index ce379efdcf0d..e3e210e7d190 100644
--- a/colossalai/_analyzer/fx/tracer/proxy.py
+++ b/colossalai/_analyzer/fx/tracer/proxy.py
@@ -1,10 +1,8 @@
import operator
-from typing import Any, Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
-import torch.nn as nn
-from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
+from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ def meta_data(self, args):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
- impl = cls._func_dispatch.pop(orig_method) # avoid recursion
+ impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -89,7 +87,6 @@ def __isinstancecheck__(self, type):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
index 2018863f6f5f..7884fd911c86 100644
--- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py
+++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@
try:
from ..codegen import ActivationCheckpointCodeGen
+
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@
def _default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def forward(self, x):
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
- bias_addition_split=bias_addition_split).trace(root.to(device),
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+ root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
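A minimal sketch of driving this tracing path directly, assuming the `ColoTracer`/`MetaTensor` APIs shown in these hunks; the toy model and tensor shapes are illustrative. Tracing on meta tensors means shape propagation allocates no real device memory:

```python
import torch
import torch.nn as nn

from colossalai._analyzer._subclasses import MetaTensor
from colossalai._analyzer.fx.tracer.tracer import ColoTracer


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

    def forward(self, x):
        return self.fc(x).relu()


tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(Toy(), meta_args={"x": MetaTensor(torch.empty(4, 16))})
print(graph)  # the bias addition shows up as a separate node
```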
diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py
index 6958a00a6a72..17dce767269d 100644
--- a/colossalai/_analyzer/fx/tracer/tracer.py
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
- return re.sub(r'_\d+$', '', s)
+ return re.sub(r"_\d+$", "", s)
-def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
-
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def wrapper(impl):
def register_leaf_module_impl(module: nn.Module):
-
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = Fal
self.ckpt_regions = []
self.ckpt_idx = 0
- self.mod_dir = ''
+ self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
- return (type(m) not in self._custom_non_leaf_module
- and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+ return type(m) not in self._custom_non_leaf_module and (
+ type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
+ )
- def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
- kwargs: Dict[str, Any]) -> Any:
+ def call_module(
+ self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ) -> Any:
curr_dir = self.mod_dir
- self.mod_dir = 'self.' + self.path_of_module(m)
+ self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ def create_proxy(self,
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self.disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self.disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ def create_node(self, *args, **kwargs) -> Node:
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -177,9 +183,7 @@ def trace(self,
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
- if k in sig_names - meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ def _check_arg_name_valid(names: Iterable[str]):
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
- self.mod_dir = 'self'
+ self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
- self.mod_dir = ''
+ self.mod_dir = ""
self.graph.lint()
for node in self.graph.nodes:
@@ -266,17 +270,17 @@ def _torch_factory_override(self):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
                        # if any arg is a proxy, we need to record this function call on the proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
- proxy = self.create_proxy('call_function', target, args, kwargs)
+ proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py
index 963215476b6b..e69de29bb2d1 100644
--- a/colossalai/amp/__init__.py
+++ b/colossalai/amp/__init__.py
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from colossalai.context import Config
-
-from .amp_type import AMP_TYPE
-from .apex_amp import convert_to_apex_amp
-from .naive_amp import convert_to_naive_amp
-from .torch_amp import convert_to_torch_amp
-
-__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
-
-
-def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
- """A helper function to wrap training components with Torch AMP modules.
-
- Args:
- param model (:class:`torch.nn.Module`): your model object.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
- criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
- mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
- amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.
-
- Returns:
- A tuple (model, optimizer, criterion).
-
- Note:
- ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
- for more details about ``amp_config``.
- For ``apex_amp``, please check
- `apex_amp config `_.
- For ``naive_amp``, please check
- `naive_amp config `_.
- For ``torch_amp``, please check
- `torch_amp config `_.
- """
- assert isinstance(mode, AMP_TYPE), \
- f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
-
- if amp_config is None:
- amp_config = Config()
-
- if mode == AMP_TYPE.TORCH:
- model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
- elif mode == AMP_TYPE.APEX:
- model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
- elif mode == AMP_TYPE.NAIVE:
- model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
-
- return model, optimizer, criterion
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
index 5b2f71d3ced7..e69de29bb2d1 100644
--- a/colossalai/amp/naive_amp/__init__.py
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -1,60 +0,0 @@
-import inspect
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.utils import is_no_pp_or_last_stage
-
-from ._fp16_optimizer import FP16Optimizer
-from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
-
-
-def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
- """A helper function to wrap training components with naive AMP modules. In this mode,
- we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
- which is equivalent to Apex O3.
-
- Args:
- model (:class:`torch.nn.Module`): your model object
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object
- amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
-
- Returns:
- Tuple: A tuple (model, optimizer)
-
- The ``amp_config`` should contain parameters below::
-
- verbose (bool, optional): if set to `True`, will print debug info (Default: False).
- clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
- Note that clipping is ignored if clip_grad == 0.
- dynamic_grad_scale (bool): whether to use dynamic grad scaler.
- """
- if isinstance(model, nn.ModuleList):
- # interleaved pipeline
- module_list = []
- for chunk, m in enumerate(model):
- output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
- module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
- model = nn.ModuleList(module_list)
- else:
- output_to_fp32 = is_no_pp_or_last_stage()
- model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
-
- use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
- if use_dynamic_grad_scaler:
- scaler_class = DynamicGradScaler
- else:
- scaler_class = ConstantGradScaler
-
- sig = inspect.signature(scaler_class.__init__)
- kwargs = dict()
- for param in sig.parameters.values():
- if param.name in amp_config:
- kwargs[param.name] = amp_config.pop(param.name)
- grad_scaler = scaler_class(**kwargs)
- optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
- return model, optimizer
-
-
-__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py
index dc8499d877e1..34a20e8d67d6 100644
--- a/colossalai/amp/naive_amp/grad_scaler/__init__.py
+++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py
@@ -2,4 +2,4 @@
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler
-__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
+__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 0d84384a7f67..79661a44424f 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -9,7 +9,7 @@
from colossalai.logging import get_dist_logger
-__all__ = ['BaseGradScaler']
+__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
@@ -30,24 +30,21 @@ def __init__(self, initial_scale: float, verbose: bool):
@property
def scale(self) -> Tensor:
- """Returns the loss scale.
- """
+ """Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
- """Returns the inverse of the loss scale.
- """
+ """Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
- """Returns the states of the gradient scaler as a dict object.
- """
+ """Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
- state_dict['scale'] = self.scale
+ state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
@@ -57,7 +54,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
state_dict (dict): the states of the gradient scaler
"""
- self._scale = state_dict['scale']
+ self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
@@ -67,8 +64,6 @@ def update(self, overflow: bool) -> None:
overflow (bool): whether overflow occurs
"""
- pass
-
def log(self, message, *args, **kwargs):
"""Log messages.
diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
index a2f518c5dd28..2ad8b51ac22c 100644
--- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
-__all__ = ['ConstantGradScaler']
+__all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler):
@@ -23,4 +23,3 @@ def update(self, overflow: bool) -> None:
Args:
overflow (bool): whether overflow occurs
"""
- pass
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index e899b9ca4c89..65133a4b3712 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -7,7 +7,7 @@
from .base_grad_scaler import BaseGradScaler
-__all__ = ['DynamicGradScaler']
+__all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler):
@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False
"""
- def __init__(self,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- min_scale: Optional[float] = None,
- max_scale: Optional[float] = None,
- hysteresis: int = 2,
- verbose: bool = False):
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ min_scale: Optional[float] = None,
+ max_scale: Optional[float] = None,
+ hysteresis: int = 2,
+ verbose: bool = False,
+ ):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
@@ -53,18 +55,17 @@ def __init__(self,
self._sanity_checks()
def _sanity_checks(self) -> None:
- """Check if the arguments are correct.
- """
+ """Check if the arguments are correct."""
if self._min_scale:
- assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
- assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
+ assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
+ assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale:
- assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
- assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
- assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
- assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
- assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
+ assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
+ assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
+ assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
+ assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
+ assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None:
"""Update the loss scale.
@@ -88,19 +89,18 @@ def update(self, overflow: bool) -> None:
self.log(
f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}",
- ranks=[0])
+ ranks=[0],
+ )
def _backoff_scale(self) -> None:
- """Decrease the loss scale
- """
+ """Decrease the loss scale"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
- """Increase the loss scale
- """
+ """Increase the loss scale"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
@@ -108,14 +108,14 @@ def _grow_scale(self) -> None:
def state_dict(self):
state_dict = dict()
- state_dict['scale'] = self._scale
- state_dict['growth_factor'] = self._growth_factor
- state_dict['backoff_factor'] = self._backoff_factor
- state_dict['hysteresis'] = self._hysteresis
+ state_dict["scale"] = self._scale
+ state_dict["growth_factor"] = self._growth_factor
+ state_dict["backoff_factor"] = self._backoff_factor
+ state_dict["hysteresis"] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
- self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
- self._growth_factor = state_dict['growth_factor']
- self._backoff_factor = state_dict['backoff_factor']
- self._hysteresis = state_dict['hysteresis']
+ self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+ self._growth_factor = state_dict["growth_factor"]
+ self._backoff_factor = state_dict["backoff_factor"]
+ self._hysteresis = state_dict["hysteresis"]
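A minimal sketch of the `DynamicGradScaler` loop, assuming a CUDA device since the scaler keeps its state in CUDA tensors (see `load_state_dict` above); the loss value is illustrative:

```python
import torch

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=1000, hysteresis=2)

loss = torch.tensor(1.5, device="cuda")
scaled_loss = loss * scaler.scale  # scale before backward to avoid fp16 underflow

# after backward, report whether any gradient overflowed this step:
# an overflow backs the scale off; a long clean streak grows it
scaler.update(overflow=False)
print(scaler.scale.item(), scaler.inv_scale.item())
```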
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
index b0348e1477bb..a31811e4a567 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
@@ -3,7 +3,7 @@
from .fp16 import FP16MixedPrecisionMixin
__all__ = [
- 'MixedPrecisionMixin',
- 'FP16MixedPrecisionMixin',
- 'BF16MixedPrecisionMixin',
+ "MixedPrecisionMixin",
+ "FP16MixedPrecisionMixin",
+ "BF16MixedPrecisionMixin",
]
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
index a52a9747ad1e..fc7e0b74179a 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -39,6 +39,7 @@ def zero_grad(self):
return self.optim.zero_grad()
```
"""
+
dtype: torch.dtype
@abstractmethod
@@ -51,7 +52,6 @@ def pre_backward(self, loss: Tensor) -> Tensor:
Returns:
Tensor: Loss value (possibly scaled).
"""
- pass
@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
@@ -64,7 +64,6 @@ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
- pass
@abstractmethod
def should_skip_step(self) -> bool:
@@ -73,13 +72,10 @@ def should_skip_step(self) -> bool:
Returns:
bool: Whether to skip the step.
"""
- pass
@abstractmethod
def pre_zero_grad(self) -> None:
- """Called before zero_grad.
- """
- pass
+ """Called before zero_grad."""
@abstractmethod
def get_grad_div_scale(self) -> float:
@@ -88,4 +84,3 @@ def get_grad_div_scale(self) -> float:
Returns:
float: A divisor for gradient clipping or step.
"""
- pass
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
index 1ce8e42eb3ed..9ce272356797 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -19,22 +19,26 @@ class OptimState(Enum):
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16
- def __init__(self,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32) -> None:
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
super().__init__()
- self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
@@ -49,7 +53,6 @@ def check_local_overflow(self) -> bool:
Returns:
bool: Whether there is overflow in the local process.
"""
- pass
def check_overflow(self) -> bool:
# clear previous overflow record
@@ -79,6 +82,6 @@ def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
- assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
+ assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
self.optim_state = OptimState.UNSCALED
return self.loss_scale
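A sketch of a concrete fp16 mixin: `check_local_overflow` is the only piece left abstract above, and here it scans a parameter list for non-finite grads, the same idea `NaiveFP16MixedPrecisionMixin` uses in the next file. Assumes a CUDA device, since the base class keeps its overflow flag on the current device:

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_mixin import FP16MixedPrecisionMixin


class SimpleFP16Mixin(FP16MixedPrecisionMixin):
    def __init__(self, params):
        super().__init__()  # default dynamic-scaling hyperparameters
        self.params = list(params)

    def check_local_overflow(self) -> bool:
        return any(
            p.grad is not None and not torch.isfinite(p.grad).all()
            for p in self.params
        )
```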
diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
index 626a00c96d04..501a843f6992 100644
--- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py
+++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -2,7 +2,7 @@
import torch
from torch import Tensor
-from torch.nn import Parameter
+from torch.nn import Module, Parameter
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
@@ -11,18 +11,20 @@
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
-
- def __init__(self,
- working_params: List[Parameter],
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32) -> None:
- super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
- max_scale)
+ def __init__(
+ self,
+ working_params: List[Parameter],
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
+ super().__init__(
+ initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+ )
self.params = working_params
def check_local_overflow(self) -> bool:
@@ -33,38 +35,41 @@ def check_local_overflow(self) -> bool:
class MixedPrecisionOptimizer(OptimizerWrapper):
-
- def __init__(self,
- optim: Optimizer,
- precision: str = 'fp16',
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0.0):
+ def __init__(
+ self,
+ optim: Optimizer,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ ):
super().__init__(optim)
- if precision == 'fp16':
+ if precision == "fp16":
working_params = []
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_params.append(p)
- self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
- initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
- elif precision == 'bf16':
+ self.mixed_precision = NaiveFP16MixedPrecisionMixin(
+ working_params,
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
+ elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin()
else:
- raise ValueError(f'Unsupported precision: {precision}')
+ raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
- raise NotImplementedError('max_norm is not supported yet.')
+ raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -72,7 +77,7 @@ def __init__(self,
# create master weights
for group in self.optim.param_groups:
master_params = []
- for p in group['params']:
+ for p in group["params"]:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
@@ -80,7 +85,7 @@ def __init__(self,
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
- group['params'] = master_params
+ group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
@@ -101,24 +106,24 @@ def _unscale_and_clip_grads(self, total_norm: float) -> None:
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
- if self.max_norm > 0.:
+ if self.max_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
- for p in group['params']:
+ for p in group["params"]:
if p.grad is None:
continue
- p.grad.data.mul_(1. / div_scale)
+ p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float:
- if self.max_norm <= 0.:
- return 0.
- grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
+ if self.max_norm <= 0.0:
+ return 0.0
+ grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
- return 0.
+ return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
@@ -130,7 +135,7 @@ def step(self, *args, **kwargs):
return
# prepare grads
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
@@ -142,8 +147,23 @@ def step(self, *args, **kwargs):
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
+
+ def update_master_params(self, model: Module):
+ # Update master params from working params
+ with torch.no_grad():
+ for p in model.parameters():
+ if (p is None) or (p not in self.working_to_master_map):
+ continue
+ master_param = self.working_to_master_map[p]
+ master_param.data.copy_(p.data)
+
+ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+ return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+ def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+ return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
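A sketch of wrapping a plain torch optimizer with the `MixedPrecisionOptimizer` above for fp16 training; assumes a CUDA device for the fp16 model and that the `OptimizerWrapper` base forwards `zero_grad`:

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

model = torch.nn.Linear(8, 8).half().cuda()
optim = MixedPrecisionOptimizer(torch.optim.SGD(model.parameters(), lr=1e-2))

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
loss = model(x).float().sum()  # keep the loss in fp32
optim.backward(loss)  # scales the loss before running backward
optim.step()  # skipped on overflow; otherwise unscale, step masters, copy back
optim.zero_grad()
```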
diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py
index af4349865a7b..7de56f80525a 100644
--- a/colossalai/auto_parallel/checkpoint/build_c_ext.py
+++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py
@@ -3,14 +3,16 @@
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
-ext_modules = [Extension(
- 'rotorc',
- sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
-)]
+ext_modules = [
+ Extension(
+ "rotorc",
+ sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
+ )
+]
setup(
- name='rotor c extension',
- version='0.1',
- description='rotor c extension for faster dp computing',
+ name="rotor c extension",
+ version="0.1",
+ description="rotor c extension for faster dp computing",
ext_modules=ext_modules,
)
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index b388d00ac553..8aaa690b333c 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -12,13 +12,13 @@
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-__all___ = ['CheckpointSolverBase']
+__all__ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
- if n_src.op == 'output':
+ if n_src.op == "output":
n_dst.meta = n_src.meta
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
-
def __init__(
self,
graph: Graph,
@@ -81,13 +80,10 @@ def __init__(
@abstractmethod
def solve(self):
- """Solve the checkpointing problem and return the solution.
- """
- pass
+ """Solve the checkpointing problem and return the solution."""
def get_node_list(self):
- """Get the node list.
- """
+ """Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
@@ -140,8 +136,7 @@ def _is_sink() -> bool:
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
@@ -150,19 +145,22 @@ def _is_inplace(n: Node):
return inplace
def _is_shape_consistency(n: Node):
- """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
- """
+ """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
- return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
- map(_is_shape_consistency, n.users))
+ return (
+ not sum([v for _, v in deps.items()])
+ and not any(map(_is_inplace, n.users))
+ and not any(map(_is_shape_consistency, n.users))
+ )
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
@@ -187,8 +185,9 @@ def _is_shape_consistency(n: Node):
region = []
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
index 19b2ef5987c9..ab16cc04b730 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -8,11 +8,10 @@
from .ckpt_solver_base import CheckpointSolverBase
-__all__ = ['CheckpointSolverChen']
+__all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase):
-
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -40,14 +39,14 @@ def solve(self) -> Graph:
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
- checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
+ checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
- n.meta['activation_checkpoint'] = i
+ n.meta["activation_checkpoint"] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
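A sketch of applying the Chen greedy solver above; `graph` is assumed to be a `torch.fx.Graph` produced by the analyzer pipeline, carrying the meta information the solver's linearization expects:

```python
from colossalai.auto_parallel.checkpoint.ckpt_solver_chen import CheckpointSolverChen

# `graph` is an assumed input from the analyzer pipeline, not built here
solver = CheckpointSolverChen(graph, num_grids=6)
optimized_graph = solver.solve()  # a deep copy with checkpoint annotations
# checkpointed nodes now carry n.meta["activation_checkpoint"] = <segment index>
```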
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 21c3bf0da758..d10c41ae2b96 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -1,5 +1,5 @@
from copy import deepcopy
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
@@ -18,17 +18,18 @@
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
-__all__ = ['CheckpointSolverRotor']
+__all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase):
-
- def __init__(self,
- graph: Graph,
- free_memory: float = -1,
- cnode: List[str] = None,
- memory_slots: int = 500,
- optim_multiplier: float = 1.0):
+ def __init__(
+ self,
+ graph: Graph,
+ free_memory: float = -1,
+ cnode: List[str] = None,
+ memory_slots: int = 500,
+ optim_multiplier: float = 1.0,
+ ):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
@@ -85,13 +86,14 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
# backtrack
try:
- self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
- self.back_ptr)
+ self.sequence = self._backtrack(
+ chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
+ )
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
# use the logger to announce that the solver failed
logger = get_dist_logger()
- logger.warning(f'Checkpoint solver failed: {e}')
+ logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError
if verbose:
@@ -100,14 +102,19 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
return deepcopy(self.graph)
def print_chain(self):
- print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
+ print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
- print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
- self.chain.btmp[idx])
- print(f'Chain = {self.chain}')
+ print(
+ self.node_list[idx],
+ self.chain.x[idx + 1],
+ self.chain.xbar[idx + 1],
+ self.chain.ftmp[idx],
+ self.chain.btmp[idx],
+ )
+ print(f"Chain = {self.chain}")
def print_sequence(self):
- print(f'Sequence = {self.sequence}')
+ print(f"Sequence = {self.sequence}")
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
@@ -138,14 +145,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
btime = 0
fwd_mem_peak = 0
for n in node:
- assert isinstance(n, Node), f'{n} is not a Node'
+ assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
- xbar += n.meta['fwd_mem_out']
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+ xbar += n.meta["fwd_mem_out"]
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
@@ -162,14 +169,14 @@ def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
- if node.op == 'placeholder':
- input_tensors.append(node.meta['fwd_out'])
+ if node.op == "placeholder":
+ input_tensors.append(node.meta["fwd_out"])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
- return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
+ return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
@@ -180,8 +187,8 @@ def _extract_deps_size():
for k, v in deps.items():
k: Node
if v > 0:
- deps_size += k.meta['bwd_mem_out']
- if v == float('-inf'):
+ deps_size += k.meta["bwd_mem_out"]
+ if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
@@ -190,12 +197,12 @@ def _extract_deps_size():
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
- btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
+ btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
- deps[child] = float('-inf') # free
+ deps[child] = float("-inf") # free
return btmp
@staticmethod
@@ -244,10 +251,11 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple:
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
- leaf_checkpoints = [(j,
- sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
- for j in range(i + 1, idx + 1)
- if m >= x[j]]
+ leaf_checkpoints = [
+ (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
+ for j in range(i + 1, idx + 1)
+ if m >= x[j]
+ ]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
@@ -274,13 +282,16 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
import os
import subprocess
import sys
+
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
- f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
- f"--build-lib={this_dir}"
+ f"{sys.executable}",
+ f"{os.path.join(this_dir, 'build_c_ext.py')}",
+ "build_ext",
+ f"--build-lib={this_dir}",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -294,8 +305,9 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
return compute_table(chain, mmax)
@staticmethod
- def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
- back_ptr: List[Any]) -> "Sequence":
+ def _backtrack(
+ chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
+ ) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
@@ -328,8 +340,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
- CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
+ ),
Backward(lhs),
]
else:
@@ -337,8 +350,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
- CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
+ ),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@@ -353,8 +367,8 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
- fwd_list = op_list[:op_list.index(loss_op)]
- bwd_list = op_list[op_list.index(loss_op) + 1:]
+ fwd_list = op_list[: op_list.index(loss_op)]
+ bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
@@ -369,7 +383,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
@@ -377,7 +391,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
@@ -397,7 +411,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
@@ -405,7 +419,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
@@ -413,7 +427,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False
@@ -431,9 +445,11 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
- for (start_idx, end_idx) in ckpt_regions:
+ for start_idx, end_idx in ckpt_regions:
nested_length = max(
- len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
+ len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
+ )
for idx in range(start_idx, end_idx + 1):
- op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
- len(op_list[idx].meta['activation_checkpoint']))
+ op_list[idx].meta["activation_checkpoint"] += [None] * (
+ nested_length - len(op_list[idx].meta["activation_checkpoint"])
+ )
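A sketch of running the rotor DP solver above under an explicit memory budget; `graph` is again an assumed input from the analyzer pipeline with profiling meta attached, and the byte budget and slot count are illustrative:

```python
from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor

# `graph` is an assumed input from the analyzer pipeline, not built here
solver = CheckpointSolverRotor(graph, free_memory=8 * 1024**3, memory_slots=500)
optimized_graph = solver.solve(verbose=True)  # logs a warning and raises on failure
solver.print_sequence()  # the chosen Forward*/Backward operation sequence
```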
diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py
index ab0c6c5ad38d..5f8077916433 100644
--- a/colossalai/auto_parallel/checkpoint/operation.py
+++ b/colossalai/auto_parallel/checkpoint/operation.py
@@ -1,20 +1,21 @@
import math
from abc import ABC
-from typing import Any, Iterable, List
+from typing import List
from torch.utils._pytree import tree_map
class Chain:
-
- def __init__(self,
- ftime: List[float],
- btime: List[float],
- x: List[int],
- xbar: List[int],
- ftmp: List[int],
- btmp: List[int],
- check_consistency: bool = True):
+ def __init__(
+ self,
+ ftime: List[float],
+ btime: List[float],
+ x: List[int],
+ xbar: List[int],
+ ftmp: List[int],
+ btmp: List[int],
+ check_consistency: bool = True,
+ ):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
@@ -37,9 +38,14 @@ def __init__(self,
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
- return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
- and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
- and (len(self.xbar) == len(self) + 1))
+ return (
+ (len(self.ftime) == len(self))
+ and (len(self.btime) == len(self) + 1)
+ and (len(self.x) == len(self) + 1)
+ and (len(self.ftmp) == len(self))
+ and (len(self.btmp) == len(self) + 1)
+ and (len(self.xbar) == len(self) + 1)
+ )
def __repr__(self):
chain_list = []
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
-
def __init__(self, start, end):
self.index = (start, end)
@@ -109,9 +114,9 @@ def __repr__(self):
def cost(self, chain: Chain):
if chain is not None:
- return sum(chain.ftime[self.index[0]:self.index[1] + 1])
+ return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
- return (self.index[1] - self.index[0] + 1)
+ return self.index[1] - self.index[0] + 1
def isForward(op):
@@ -132,7 +137,6 @@ def cost(self, chain: Chain):
class Loss(Operation):
-
def __init__(self):
pass
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
-
def __init__(self):
super().__init__()
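
For reference, the length invariants checked by `check_lengths` come straight from the chain structure in the Rotor paper: forward-side lists carry one entry per segment, backward-side lists one extra. A consistent instantiation might look like this (values are illustrative; the import path matches the file shown above):

```python
# Illustrative values; import path matches colossalai/auto_parallel/checkpoint/operation.py.
from colossalai.auto_parallel.checkpoint.operation import Chain

L = 3  # number of chain segments
chain = Chain(
    ftime=[1.0] * L,        # forward time per segment: L entries
    btime=[1.0] * (L + 1),  # backward time: L + 1 entries
    x=[4] * (L + 1),        # activation sizes: L + 1 entries
    xbar=[8] * (L + 1),     # checkpoint sizes: L + 1 entries
    ftmp=[0] * L,           # forward temp memory: L entries
    btmp=[0] * (L + 1),     # backward temp memory: L + 1 entries
)
assert chain.check_lengths()  # would have raised in __init__ otherwise
```
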
diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py
index 35b8c13ee8ff..2f638fa919e4 100644
--- a/colossalai/auto_parallel/meta_profiler/constants.py
+++ b/colossalai/auto_parallel/meta_profiler/constants.py
@@ -3,8 +3,6 @@
import torch
import torch.nn as nn
-from ..tensor_shard.constants import *
-
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 0f2e9e44f91c..4234481ae2ca 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
- lambda x:
- (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
- args)).data
+ lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
+ and x.name != "softmax_dim",
+ args,
+ )
+ ).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
- is_inplace = 1 if kwargs.get('inplace', False) else 0
+ is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if in_place is True, we will not create a new tensor in forward
- fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
- parameter=0,
- temp=0,
- buffer=activation_size(input_tensor) * buffer_mem_scale)
+ fwd_memory_cost = MemoryCost(
+ activation=activation_size(input_tensor) * (2 - is_inplace),
+ parameter=0,
+ temp=0,
+ buffer=activation_size(input_tensor) * buffer_mem_scale,
+ )
        # temp_mem_scale is for situations like softmax backward,
        # where the buffer is removed during the backward phase
@@ -54,20 +58,23 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
- buffer=0)
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
- fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
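
The `2 - is_inplace` factor above encodes that an out-of-place elementwise op keeps both input and output alive, while an in-place op reuses one buffer. A self-contained sketch of that accounting, with a local byte counter standing in for `compute_size_in_bytes`:

```python
import torch

def nbytes(t: torch.Tensor) -> int:
    # Local stand-in for colossalai's compute_size_in_bytes.
    return t.numel() * t.element_size()

x = torch.zeros(1024, 1024, dtype=torch.float16, device="meta")

for is_inplace in (0, 1):
    fwd_activation = nbytes(x) * (2 - is_inplace)
    print(f"inplace={bool(is_inplace)}: fwd activation = {fwd_activation} bytes")
# inplace=False counts input + output (2x); inplace=True reuses the buffer (1x).
```
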
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index e451748512b9..0b7b51a71955 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -6,10 +6,10 @@
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
+from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
-__all__ = ['binary_elementwise_meta_info']
+__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
+ fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index 4336bf68363c..2f630995cdbc 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['convnd_meta_info']
+__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
- flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
+ if has_bias
+ else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
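
Both conv memory terms above branch on `has_bias`, changing only which tensors are counted. A compact sketch of the same accounting with a stand-in byte counter (shapes are illustrative):

```python
import torch

def nbytes(tensors) -> int:
    # Local stand-in for colossalai's compute_size_in_bytes.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)

weight = torch.zeros(64, 3, 3, 3, device="meta")  # Conv2d(3, 64, 3) weight
bias = torch.zeros(64, device="meta")

for has_bias in (True, False):
    param_bytes = nbytes([weight, bias]) if has_bias else nbytes(weight)
    print(f"has_bias={has_bias}: parameter memory = {param_bytes} bytes")
```
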
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
index d5d80f5b3700..7c9add810fd8 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
- bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
- [weight_tensor])
+ bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
+ [output_tensor, weight_tensor], [weight_tensor]
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
    # NOTE: during the backward phase of torch.nn.Embedding, a temporary allocation appears once the input
    # is large enough; we have not yet identified the cause, so we currently assume there is no temp memory,
    # as it is significantly smaller than the gradient memory
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=0,
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
+ )
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 94dd9143e0ae..d731f9cb4436 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -1,23 +1,15 @@
from functools import reduce
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
-__all__ = ['linear_meta_info', 'matmul_meta_info']
+__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
- [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
- flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
+ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
+ )
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
    # NOTE: Linear has no buffer or temp memory in the forward and backward phases
    # the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
    # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor; the parameter cost is the size of weight_tensor and bias_tensor
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
    # total cost is the sum of the forward and backward costs
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
+ [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [output_tensor, weight_tensor], (input_tensor,)
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
    # NOTE: Linear has no buffer or temp memory in the forward and backward phases
    # the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
    # the backward activation cost is the size of input_tensor and weight_tensor; the parameter cost is the size of weight_tensor
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
    # total cost is the sum of the forward and backward costs
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[1]],
- output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors)
+ [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
- bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+ [output_tensors[0], input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
- [output_tensors[0].reshape(-1)])
+ [output_tensors[0].reshape(-1)],
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[0]],
- output_tensors
- ) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors
- )
+ [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
+ output_tensors[0].reshape(-1),
+ ],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
- [input_tensors[1]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
- )
+ [
+ input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
+ output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
+ ],
+ [input_tensors[1]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
- fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
- input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
- 0, 1)
- ], [output_tensors[0].transpose(-2, -1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ input_tensors[0].transpose(0, 1),
+ ],
+ [output_tensors[0].transpose(-2, -1)],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
- [input_tensors[0]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
- )
-
- fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
- compute_size_in_bytes(input_tensors[1]),
- temp=compute_size_in_bytes(output_tensors))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
+ [
+ output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ ],
+ [input_tensors[0]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+ )
+
+ fwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
+ temp=compute_size_in_bytes(output_tensors),
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
+ )
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
- for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+ for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
- fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
- input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
- -1, input_dim_10, input_dim_11)
- ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [
+ input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
+ input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
+ ],
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
- [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
- )
+ [
+ input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ ],
+ [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
+ ],
+ [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
- extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_00,
- input_dim_01,
- device="meta")
- extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_10,
- input_dim_11,
- device="meta")
+ extended_input_0 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
+ )
+ extended_input_1 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
+ )
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [extended_input_1]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
- [extended_input_0]
- )
+ [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ [extended_input_1],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+ [extended_input_0],
+ )
fwd_mem_cost = MemoryCost(
- activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
- compute_size_in_bytes([extended_input_0, extended_input_1]),
- temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
+ activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors)
+ - compute_size_in_bytes([extended_input_0, extended_input_1]),
+ temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
+ )
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
- total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
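
A useful sanity check on the GEMM costs above: the backward pass performs two matmuls (dA = dC·Bᵀ and dB = Aᵀ·dC), so under a plain 2·m·k·n FLOP model it costs exactly twice the forward pass. A quick sketch (the FLOP formula, not colossalai's `flop_mapping`, is the assumption here):

```python
def mm_flops(m: int, k: int, n: int) -> int:
    # Multiply-accumulate count for an (m, k) x (k, n) matmul.
    return 2 * m * k * n

m, k, n = 128, 256, 512
fwd = mm_flops(m, k, n)                      # C = A @ B
bwd = mm_flops(m, n, k) + mm_flops(k, m, n)  # dA = dC @ B^T, dB = A^T @ dC
assert bwd == 2 * fwd
```
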
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
index 12874810b13e..b1bb1d872c35 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -3,7 +3,7 @@
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
index b872fdc8bdcd..99aaa752d0a1 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
+__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d)
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
- output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
+ output_tensor,
+ output_tensor,
+ weight_tensor,
+ mean_tensor,
+ var_tensor,
+ mean_tensor,
+ var_tensor,
+ 1e-5,
+ num_batch,
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, mean_tensor, var_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
    # the bwd memory cost is tricky here: BatchNorm removes the saved mean
    # and saved inv std during the backward phase
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([mean_tensor, var_tensor]),
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
- running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
- running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
+ running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
+ running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([running_mean, running_var]),
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([running_mean, running_var]),
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
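
The BatchNorm accounting above treats the saved mean and inverse std as `buffer` memory in the forward pass and as `temp` in the backward pass, since they are dropped once gradients are computed. A sketch of that bookkeeping with a local byte counter:

```python
import torch

def nbytes(tensors) -> int:
    return sum(t.numel() * t.element_size() for t in tensors)

num_features = 256
saved_mean = torch.zeros(num_features, device="meta")
saved_inv_std = torch.zeros(num_features, device="meta")

fwd_buffer = nbytes([saved_mean, saved_inv_std])  # held across the forward pass
bwd_temp = nbytes([saved_mean, saved_inv_std])    # freed during the backward pass
print(fwd_buffer, bwd_temp)  # 1024 1024 for fp32
```
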
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index d785dfcca9ba..21aa524bed08 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
- temp=compute_size_in_bytes(index_matrix))
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+ temp=compute_size_in_bytes(index_matrix),
+ )
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
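
The index matrix that max pooling saves as a forward buffer is simply the argmax positions; PyTorch exposes it via `return_indices`. A small illustration:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# return_indices=True exposes the index matrix treated as a buffer above.
out, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)
print(out.shape, idx.shape)  # both torch.Size([1, 3, 4, 4])
```
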
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
index 97fe3c6196f5..9a2df1bd7c87 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
@@ -2,7 +2,6 @@
import torch
-from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
@@ -37,15 +36,19 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
- parameter=0,
- temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
+ parameter=0,
+ temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
+ buffer=0,
+ )
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
@@ -66,14 +69,24 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
# register torch.Tensor related metainfo
# (0, 0)
-meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
- torch.arange])(tensor_related_metainfo(0, 0))
+meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
+ tensor_related_metainfo(0, 0)
+)
# (1, 0)
-meta_register.register([
- torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
- torch.Tensor.split, torch.split, torch.Tensor.view
-])(tensor_related_metainfo(1, 0))
+meta_register.register(
+ [
+ torch.Tensor.flatten,
+ torch.flatten,
+ torch.Tensor.transpose,
+ torch.transpose,
+ torch.Tensor.permute,
+ torch.permute,
+ torch.Tensor.split,
+ torch.split,
+ torch.Tensor.view,
+ ]
+)(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
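
The `tensor_related_metainfo(out_factor, tmp_factor)` calls above are factory invocations: each returns a meta function parameterized by how much backward output and temp memory an op needs relative to its outputs. A toy sketch of the factory shape (names and return values are illustrative, not the real API):

```python
from typing import Callable, Tuple

def tensor_metainfo_factory(out_factor: float, tmp_factor: float) -> Callable[[int], Tuple[float, float]]:
    # Hypothetical factory mirroring tensor_related_metainfo's structure:
    # it closes over the two factors and returns a per-op meta function.
    def meta_func(output_bytes: int) -> Tuple[float, float]:
        bwd_activation = output_bytes * out_factor
        bwd_temp = output_bytes * tmp_factor
        return bwd_activation, bwd_temp
    return meta_func

view_like = tensor_metainfo_factory(1, 0)   # flatten/transpose/view group above
alias_like = tensor_metainfo_factory(0, 0)  # tensor/to/unsqueeze/arange group
print(view_like(1024), alias_like(1024))    # (1024.0, 0.0) (0.0, 0.0)
```
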
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
index 5cba1b5b6e2b..107851b80d7c 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
@@ -4,7 +4,7 @@
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
- bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
- parameter=0,
- temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
- activation_size([x_tensor, y_tensor]),
- buffer=0)
-
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ bwd_mem_cost = MemoryCost(
+ activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+ parameter=0,
+ temp=activation_size([output_tensor]) * 3
+ + activation_size([condition_tensor])
+ - activation_size([x_tensor, y_tensor]),
+ buffer=0,
+ )
+
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py
index 46350c4dd406..c29086f7f9d1 100644
--- a/colossalai/auto_parallel/meta_profiler/registry.py
+++ b/colossalai/auto_parallel/meta_profiler/registry.py
@@ -1,14 +1,12 @@
-__all__ = ['Registry']
+__all__ = ["Registry"]
class Registry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -21,7 +19,7 @@ def wrapper(func):
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -29,4 +27,4 @@ def has(self, source):
return source in self.store
-meta_register = Registry('meta')
+meta_register = Registry("meta")
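
A short usage sketch for the `Registry` above; `register` accepts either a single source or a list/tuple, and `get`/`has` look entries up (the registered names here are illustrative):

```python
# Import path matches colossalai/auto_parallel/meta_profiler/registry.py.
from colossalai.auto_parallel.meta_profiler.registry import Registry

registry = Registry("demo")

def activation_meta(*args, **kwargs):
    return "activation meta info"

# register() accepts a single source or a list/tuple of sources.
registry.register(["relu", "gelu"])(activation_meta)

assert registry.has("relu") and registry.has("gelu")
assert registry.get("gelu")() == "activation meta info"
```
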
diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
index 0eee908b48b7..109b8a220ac7 100644
--- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
@@ -2,20 +2,13 @@
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
-__all__ = ['ShardMetaInfo']
+__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
@@ -76,10 +69,12 @@ def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: S
"""
if isinstance(sharding_spec, ShardingSpec):
- op_data = OperationData(name=operation_data.name,
- data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
- type=operation_data.type,
- logical_shape=operation_data.logical_shape)
+ op_data = OperationData(
+ name=operation_data.name,
+ data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+ type=operation_data.type,
+ logical_shape=operation_data.logical_shape,
+ )
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
@@ -97,8 +92,9 @@ def compute_shard_metainfo(self):
"""
Compute meta info based on sharding strategy and the given target function.
"""
- assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
- f"Meta info for {self._target} is not registered."
+ assert meta_register.has(self._target.__class__) or meta_register.has(
+ self._target
+ ), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
@@ -117,11 +113,11 @@ def compute_shard_metainfo(self):
# construct kwargs
if self.target in INPLACE_MODULE:
- kwargs = {'inplace': self.target.inplace}
+ kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
- kwargs = {'inplace': True}
+ kwargs = {"inplace": True}
else:
- kwargs = {'inplace': False}
+ kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
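
`compute_sharded_opdata` above re-materializes each operand at its per-device shape on the meta device, so the meta functions cost the sharded tensors rather than the logical ones. A toy sketch of that shape reduction (the dict-based sharding notation is illustrative, not colossalai's `ShardingSpec`):

```python
import torch

def sharded_shape(global_shape, dim_to_num_devices):
    # Toy stand-in: each entry maps a tensor dim to the device count splitting it.
    return tuple(
        size // dim_to_num_devices.get(dim, 1) for dim, size in enumerate(global_shape)
    )

shape = sharded_shape((1024, 1024), {0: 4})  # shard dim 0 across 4 devices
sharded = torch.zeros(shape, device="meta")  # costed at the per-device shape
assert sharded.shape == (256, 1024)
```
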
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index 19d85b80dd3d..601bf2926d99 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -5,8 +5,8 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
@@ -19,7 +19,7 @@ class OptimState(Enum):
UNSCALED = 1
-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(OptimizerWrapper):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@@ -37,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer):
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
- def __init__(self,
- optimizer: Optimizer,
- module: BaseOffloadModule,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- clipping_norm: float = 0.0,
- norm_type: float = 2.0):
-
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ module: BaseOffloadModule,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ clipping_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ):
super().__init__(optimizer)
self.module = module
@@ -69,19 +70,21 @@ def __init__(self,
self.__init__optimizer()
# Grad scaler
- self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
@@ -92,7 +95,7 @@ def _set_grad_ptr(self):
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
@@ -130,10 +133,10 @@ def step(self, *args, **kwargs):
found_inf = self._check_overflow()
if found_inf:
- self.optim_state = OptimState.UNSCALED # no need to unscale grad
- self.grad_scaler.update(found_inf) # update gradient scaler
- self._logger.info(f'Found overflow. Skip step')
- self.zero_grad() # reset all gradients
+ self.optim_state = OptimState.UNSCALED # no need to unscale grad
+ self.grad_scaler.update(found_inf) # update gradient scaler
+            self._logger.info("Found overflow. Skipping step.")
+ self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@@ -156,11 +159,10 @@ def backward(self, loss: torch.Tensor):
self.module.backward(loss)
def __init__optimizer(self):
-
for group in self.optim.param_groups:
fake_params_list = list()
- for param in group['params']:
+ for param in group["params"]:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
@@ -171,7 +173,7 @@ def __init__optimizer(self):
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
- group['params'] = fake_params_list
+ group["params"] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
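
The overflow branch above skips the step and lets the grad scaler back off. A minimal sketch of the dynamic loss-scaling policy implied by the constructor arguments (this toy class is not `DynamicGradScaler`'s real API; hysteresis is omitted):

```python
class ToyGradScaler:
    """Dynamic loss scaling: back off on overflow, grow after a quiet streak."""

    def __init__(self, initial_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000,
                 min_scale=1.0, max_scale=2**32):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale, self.max_scale = min_scale, max_scale
        self._good_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: shrink the scale and restart the growth counter.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._good_steps = 0

scaler = ToyGradScaler()
scaler.update(found_inf=True)
assert scaler.scale == 2**15
```
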
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
index 5b9f74b132f3..f5e8e31f5e97 100644
--- a/colossalai/auto_parallel/offload/base_offload_module.py
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -22,7 +22,6 @@ class BaseOffloadModule:
"""
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
-
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
@@ -91,17 +90,16 @@ def _cast_buffers(self):
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
- def named_parameters(self, prefix: str = '', recurse: bool = True):
+ def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
- def named_buffers(self, prefix: str = '', recurse: bool = True):
+ def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
- def named_modules(self,
- memo: Optional[Set[torch.nn.Module]] = None,
- prefix: str = '',
- remove_duplicate: bool = True):
+ def named_modules(
+ self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+ ):
return self.model.named_modules(memo, prefix, remove_duplicate)
diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py
index d56166dea982..74501c184518 100644
--- a/colossalai/auto_parallel/offload/mem_optimize.py
+++ b/colossalai/auto_parallel/offload/mem_optimize.py
@@ -14,11 +14,9 @@
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
-def memory_optimize(model: torch.nn.Module,
- inps: Dict[str, torch.Tensor],
- memory_budget: float = -1.0,
- solver_name: str = 'asyn'):
-
+def memory_optimize(
+ model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
+):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
- if solver_name == 'syn':
+ if solver_name == "syn":
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
- elif solver_name == 'asyn':
+ elif solver_name == "asyn":
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
- optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
+ optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
return optimized_model
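
The reformatted `memory_optimize` above is the whole public surface of this offload feature: hand it a model, example inputs, an optional memory budget, and a solver name, and it returns a wrapped module. A hedged usage sketch; the toy model and budget value are assumptions, the input dict keys are assumed to match the forward-argument names, and the budget unit is bytes judging by the comparison against `total_memory` in the solver later in this diff:

```python
import torch
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
inps = {"input": torch.randn(8, 1024)}

# "asyn" overlaps H2D/D2H copies with compute via prefetching;
# "syn" issues every copy synchronously (see the pass selection above).
optimized_model = memory_optimize(model, inps, memory_budget=8 * 1024**3, solver_name="asyn")
```
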
diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py
index 819ffbd96eb1..ea92c714ce31 100644
--- a/colossalai/auto_parallel/offload/region.py
+++ b/colossalai/auto_parallel/offload/region.py
@@ -55,13 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
Map the parameters in the region to a contiguous memory space.
"""
- self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
+ self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
- self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
- param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
+ self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
+ param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
@@ -83,7 +83,7 @@ def move_param_to_cuda(self):
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
- self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
+ self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
@@ -94,7 +94,7 @@ def move_grad_to_cpu(self):
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
- self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
+ self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
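
Behind the slice-spacing fixes, `Region.init_param_data` implements parameter flattening: every parameter of a region is copied into one contiguous fp16 buffer and then re-aliased as a view into it, so the whole region can be moved with a single H2D/D2H copy. The idea as a standalone sketch (simplified, CPU-only):

```python
import torch

def pack_params(params, dtype=torch.half):
    """Pack tensors into one contiguous buffer and alias them back as views."""
    flat = torch.zeros(sum(p.numel() for p in params), dtype=dtype)
    param_to_range, offset = {}, 0
    for p in params:
        n = p.numel()
        flat[offset : offset + n].copy_(p.detach().flatten())
        p.data = flat[offset : offset + n].view(p.shape)  # p now aliases the buffer
        param_to_range[p] = (offset, offset + n)
        offset += n
    return flat, param_to_range

w, b = torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(4))
flat, ranges = pack_params([w, b])
assert flat.data_ptr() == w.data.data_ptr()  # one storage for the whole region
```
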
diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py
index 30bfaf00d493..146dd267967d 100644
--- a/colossalai/auto_parallel/offload/region_manager.py
+++ b/colossalai/auto_parallel/offload/region_manager.py
@@ -1,10 +1,11 @@
-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
+
import torch
from torch.fx import Graph, Node
+from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
-from .region import Region
from .util import NodeInfo
@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node list; it should be a subset of the input nodes.
"""
- def __init__(self,
- graph: Graph,
- solver_name: str = 'asyn',
- memory_budget: float = -1.0,
- cnode: List[str] = None):
-
+ def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+ assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
@@ -39,7 +35,7 @@ def __init__(self,
self.memory_budget = memory_budget
self.solver_name = solver_name
- self.require_pool: bool = solver_name == 'asyn'
+ self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict()
@@ -61,22 +57,19 @@ def _build_regions(self):
self._post_process(solver.best_ts)
def _pre_process(self):
-
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
- raise NotImplementedError(
- 'The current version only considers at most one pair of parameter sharing.')
+ raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
- assert shared_regs[0].shared_rid == shared_regs[1].r_id \
- and shared_regs[1].shared_rid == shared_regs[0].r_id
+ assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
- regs_left_out = init_region_list[:fst_id + 1]
+ regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:]
- hold_regs = init_region_list[fst_id + 1:lst_id]
+ hold_regs = init_region_list[fst_id + 1 : lst_id]
else:
regs_left_out = []
regs_right_out = []
@@ -122,12 +115,9 @@ def _early_region_placement(self, ts: TrainingSimulator):
it may not find a suitable region placement strategy for the given execution flow.
"""
- reg_flow = torch.cat(
- [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
- mem_block_num = torch.max(
- torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
- coexist_matrix = torch.logical_or(
- ts.fwd_reg_flow, ts.bwd_reg_flow)
+ reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+ mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+ coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
@@ -135,8 +125,7 @@ def _early_region_placement(self, ts: TrainingSimulator):
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
- cur_reg_coexists = torch.sum(
- coexist_matrix[cur_reg_appears], dim=0).bool()
+ cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
@@ -145,9 +134,12 @@ def _early_region_placement(self, ts: TrainingSimulator):
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
- f'can not find a block from the memory pool to store parameters of the region')
- self.memory_pool = torch.chunk(torch.zeros(int(
- mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+ f"can not find a block from the memory pool to store parameters of the region"
+ )
+ self.memory_pool = torch.chunk(
+ torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+ chunks=int(mem_block_num),
+ )
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
@@ -178,10 +170,9 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
return region_list
- def _search_block_size(self,
- region_list: List[Region],
- search_interval_byte: int = 1024,
- search_range_byte: int = 128 * 1024 ** 2) -> int:
+ def _search_block_size(
+ self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+ ) -> int:
"""
Search for a suitable memory block size.
@@ -208,11 +199,10 @@ def _get_wasted_mem(size_list: List[int], blk_size: int):
acc_wasted += blk_size - left
return acc_wasted
- param_size_list = [
- region.param_size for region in region_list if region.r_id == region.shared_rid]
+ param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
- min_mem_waste = float('+inf')
+ min_mem_waste = float("+inf")
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
@@ -229,7 +219,7 @@ def _init_region_data(self):
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
- self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+ self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
@@ -244,8 +234,7 @@ def _init_region_data(self):
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
- region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
- )
+ region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
torch.cuda.empty_cache()
@@ -259,13 +248,14 @@ def _process_shared_region(self):
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
- assert embedding_node.op == 'call_module' and isinstance(
- self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+ assert embedding_node.op == "call_module" and isinstance(
+ self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+ )
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
- if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
- torch.nn.Linear)) or \
- (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+ if (
+ n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
+ ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
@@ -273,7 +263,7 @@ def _process_shared_region(self):
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
- for reg in self.region_list[new_reg.r_id + 1:]:
+ for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
@@ -344,8 +334,8 @@ def _maybe_param_comp_start() -> bool:
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
@@ -362,14 +352,12 @@ def _is_param_comp_end() -> bool:
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
- inplace = getattr(n.graph.owning_module.get_submodule(
- n.target), "inplace", False)
+ inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
label = False
@@ -378,28 +366,30 @@ def _is_inplace(n: Node):
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
- map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+ map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+ )
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
- if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
- n.meta['fwd_out'] = []
+ if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
+ n.meta["fwd_out"] = []
# make sure that each item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
@@ -428,8 +418,8 @@ def _exception_node_handling():
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
- ns = region.nodes[border_n_idx + 1:]
- region.nodes = region.nodes[:border_n_idx + 1]
+ ns = region.nodes[border_n_idx + 1 :]
+ region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
@@ -448,19 +438,21 @@ def _exception_node_handling():
region = Region(r_id=region_id)
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
- deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ deps[n] = len([user for user in n.users if user.op != "output"])
# propagate param node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
- ]) or n.op == "get_attr":
+ if (
+ len(n.all_input_nodes)
+ == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+ or n.op == "get_attr"
+ ):
self.only_param_ops.append(n.name)
- param_op_deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ param_op_deps[n] = len([user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
@@ -472,19 +464,16 @@ def _exception_node_handling():
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
-
cur_n.node_info = NodeInfo(node_id)
- if cur_n.op == 'call_module':
+ if cur_n.op == "call_module":
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
-
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[p], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
@@ -499,12 +488,10 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
-
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[attr_itr], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
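
Most of the `RegionManager` churn above is cosmetic, but `_search_block_size` is worth a second look: all pooled regions share fixed-size memory blocks, and the block size is grid-searched, starting from the largest region, to minimize padding waste. A simplified standalone version of that search (the waste model is reconstructed from the `_get_wasted_mem` fragment above and may differ in detail):

```python
def search_block_size(region_sizes, interval=1024, search_range=128 * 1024**2):
    """Grid-search the block size that wastes the least padding."""
    def wasted(blk):
        # Each region is padded up to a multiple of the block size.
        return sum((blk - s % blk) % blk for s in region_sizes)

    start = max(region_sizes)                  # a block must fit the largest region
    best_blk, best_waste = start, float("+inf")
    for blk in range(start, start + search_range + 1, interval):
        if wasted(blk) < best_waste:
            best_blk, best_waste = blk, wasted(blk)
    return best_blk

print(search_block_size([3 * 1024, 5 * 1024, 9 * 1024], search_range=8 * 1024))
```
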
diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py
index 764ac608826b..cc790dfb0891 100644
--- a/colossalai/auto_parallel/offload/runtime.py
+++ b/colossalai/auto_parallel/offload/runtime.py
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- d2h_rid = fwd_info.get('d2h_rid', None)
+ d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
@@ -38,8 +38,7 @@ def forward(ctx, input_, fwd_info, bwd_info):
@staticmethod
def backward(ctx, grad_output):
-
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- sync_rid = fwd_info.get('sync_rid', None)
+ sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -87,8 +86,7 @@ def forward(ctx, input_, fwd_info, bwd_info):
@staticmethod
def backward(ctx, grad_output):
-
- sync_rid = ctx.bwd_info.get('sync_rid', None)
+ sync_rid = ctx.bwd_info.get("sync_rid", None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
@@ -98,7 +96,7 @@ def backward(ctx, grad_output):
else:
wait_region.move_param_to_cuda()
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -114,7 +112,7 @@ def backward(ctx, grad_output):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Upload and Offload operation into runtime action.
Arguments:
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Prefetch and Offload operation into runtime action.
Arguments:
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
- fwd_info['h2d_rid'] = region.r_id
+ fwd_info["h2d_rid"] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
- upload_apply_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, {}))
+ upload_apply_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
+ )
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
fwd_info = {}
if region.param_size:
- fwd_info['sync_rid'] = region.r_id
+ fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
- fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
+ fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['sync_rid'] = r_idx - 1
+ bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function",
+ convert_fwd_prefetch_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, bwd_info),
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
- bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
+ bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, {}, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm
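
The two autograd `Function`s reformatted above are the runtime glue: they are spliced into the FX graph as identity ops whose `forward` uploads or frees regions before a layer runs and whose `backward` prefetches regions for the layers behind it. The trick in isolation (illustrative; the real ops consult `GlobalRuntimeInfo` and CUDA streams):

```python
import torch

class PreFwdPostBwdHook(torch.autograd.Function):
    """Identity op that fires one callback in forward and another in backward."""

    @staticmethod
    def forward(ctx, x, fwd_action, bwd_action):
        ctx.bwd_action = bwd_action
        fwd_action()                      # e.g. upload params / free a finished region
        return x

    @staticmethod
    def backward(ctx, grad_output):
        ctx.bwd_action()                  # e.g. prefetch params for the previous region
        return grad_output, None, None    # callbacks get no gradients

x = torch.randn(3, requires_grad=True)
y = PreFwdPostBwdHook.apply(x, lambda: print("pre-forward"), lambda: print("post-backward"))
y.sum().backward()                        # triggers the backward callback
```
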
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index 161f7ff86898..a6b4904f2617 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -1,6 +1,6 @@
import time
-from typing import List, Dict, Type
from abc import ABC, abstractmethod
+from typing import Dict, List, Type
NOT_NVML = False
try:
@@ -10,10 +10,11 @@
import torch
from torch.fx.node import Node
+
from colossalai.utils.cuda import get_current_device
-from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
+from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget, compensating for errors in the estimation of peak memory and execution time.
"""
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- error_factor: float = 0.95) -> None:
-
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
- self.memory_budget = torch.cuda.get_device_properties(
- get_current_device()).total_memory * self.error_factor
+ self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
@@ -94,7 +90,7 @@ def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: floa
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
- return (float('inf'), total_mem_saving, peak_mem_saving)
+ return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
@@ -122,9 +118,7 @@ def _update_state(self, best_ts: TrainingSimulator):
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
- def _update_node_mem_info(self,
- fwd_mem_info: Dict[Node, float],
- bwd_mem_info: Dict[Node, float]):
+ def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
@@ -134,12 +128,10 @@ def _update_node_mem_info(self,
"""
for node, mem in fwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
@@ -159,12 +151,12 @@ def _extract_computing_power(self):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
- elif device_name.__contains__('V100'):
+ elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
- raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
+ raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self):
"""
@@ -172,9 +164,9 @@ def _profile_bandwidth(self):
using data volumes ranging from 1KB to 1GB.
"""
- print('profiling bandwidth ......')
+ print("profiling bandwidth ......")
link_to_bandwidth = {}
- links = ['h2d', 'd2h']
+ links = ["h2d", "d2h"]
for link in links:
t_size = 1024
@@ -182,24 +174,22 @@ def _profile_bandwidth(self):
# from 1KB to 1GB
for i in range(21):
- if link == 'h2d':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, pin_memory=True)
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, device='cuda')
- elif link == 'd2h':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, device='cuda')
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, pin_memory=True)
+ if link == "h2d":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
+ elif link == "d2h":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
- print(f'size: {t_size / 1024 ** 2:.3f} MB, '
- f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
- f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
+ print(
+ f"size: {t_size / 1024 ** 2:.3f} MB, "
+ f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
+ f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
+ )
t_size *= 2
@@ -208,10 +198,7 @@ def func():
class SynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0) -> None:
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
@@ -258,7 +245,8 @@ def _call_solver(self):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
def _call_solver_l2l(self):
"""
@@ -270,7 +258,6 @@ def _call_solver_l2l(self):
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
-
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
@@ -297,23 +284,17 @@ def _eval_one_choice(self, offload_region: Region):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
- extra_comm_cost = 2.0 * \
- ts._get_communication_overhead('h2d', offload_region.param_size)
+ extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- search_window_size: int = 3):
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
@@ -331,7 +312,7 @@ def _init_state(self):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
- print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
+ print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self):
"""
@@ -358,18 +339,17 @@ def _call_solver(self):
best_pref_ts = None
# search when to prefetch the region offloaded
- for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
+ for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
- temp_ts, profit = self._try_to_offload(
- host_region, region)
+ temp_ts, profit = self._try_to_offload(host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
- if profit[0] == float('inf'):
+ if profit[0] == float("inf"):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
@@ -392,7 +372,8 @@ def _call_solver(self):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
region_to_region_map.clear()
@@ -452,7 +433,6 @@ def _repair_strategy(self):
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
-
max_profit = (0,)
best_ts = None
undo_host_region = None
@@ -464,8 +444,7 @@ def _repair_strategy(self):
assert offload_region.need_offload
assert not offload_region.is_syn
- ts, profit = self._try_convert_to_syn_upload(host_region,
- offload_region)
+ ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
@@ -474,7 +453,7 @@ def _repair_strategy(self):
best_ts = ts
if best_ts is None:
- raise NotImplementedError('repair error!')
+ raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
@@ -500,17 +479,13 @@ def _eval_one_choice(self):
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
- solvers: Dict[str, Type[Solver]] = {
- 'syn': SynGreedySolver,
- 'asyn': AsynGreedySolver
- }
+ solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
@staticmethod
def create(solver_name: str) -> Type[Solver]:
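
The greedy solvers above decide which regions to offload by replaying a `TrainingSimulator` and comparing profit tuples; the bandwidth figures they rely on come from `_profile_bandwidth`, which times pinned-host/device copies over sizes from 1 KB to 1 GB. A self-contained version of that micro-benchmark (a sketch using wall-clock timing in place of the `benchmark_func` helper):

```python
import time
import torch

def copy_bandwidth(link: str, size_bytes: int, repeat: int = 5) -> float:
    """Measure host<->device copy bandwidth in GB/s. Requires a CUDA device."""
    if link == "h2d":
        src = torch.ones(size_bytes, dtype=torch.int8, pin_memory=True)
        dst = torch.empty(size_bytes, dtype=torch.int8, device="cuda")
    else:  # "d2h"
        src = torch.ones(size_bytes, dtype=torch.int8, device="cuda")
        dst = torch.empty(size_bytes, dtype=torch.int8, pin_memory=True)
    dst.copy_(src)                         # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(repeat):
        dst.copy_(src)
    torch.cuda.synchronize()
    return size_bytes * repeat / (time.perf_counter() - t0) / 1024**3

if torch.cuda.is_available():
    for link in ("h2d", "d2h"):
        print(f"{link}: {copy_bandwidth(link, 64 * 1024**2):.2f} GB/s")  # 64 MB transfers
```
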
diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py
index de58023ec2d6..728d8daf9a46 100644
--- a/colossalai/auto_parallel/offload/training_simulator.py
+++ b/colossalai/auto_parallel/offload/training_simulator.py
@@ -1,7 +1,7 @@
import bisect
-from typing import List, Dict
-from collections import OrderedDict
from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Dict, List
from torch.fx.node import Node
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list
self.region_num = len(region_list)
@@ -87,11 +84,7 @@ def _get_computing_overhead(self, flop: float) -> float:
class SynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
def execute(self):
@@ -115,8 +108,7 @@ def _eval_fwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -141,18 +133,15 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -160,12 +149,14 @@ def _eval_bwd_mem_per_region(self, region: Region):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!"
+ )
# release parameter and offload gradient in region
if region.r_id == region.shared_rid:
@@ -177,23 +168,16 @@ def _eval_bwd_mem_per_region(self, region: Region):
class AsynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0
# the last computation execution period
- self.last_comp: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last parameter prefetch execution period
- self.last_h2d: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last gradient offload execution period
- self.last_d2h: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region
@@ -204,10 +188,8 @@ def __init__(self,
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
- self.bwd_reg_to_offl_waiting: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
- self.bwd_reg_to_offl_freed: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = []
@@ -217,10 +199,8 @@ def __init__(self,
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
- self.fwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
- self.bwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
+ self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
+ self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
def execute(self):
"""
@@ -232,7 +212,7 @@ def execute(self):
for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1:
- for nr in self.region_list[reg.r_id + 1:]:
+ for nr in self.region_list[reg.r_id + 1 :]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr
break
@@ -249,8 +229,7 @@ def execute(self):
self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear()
- self.iter_end_time = max(
- self.last_comp.end_time, self.last_d2h.end_time)
+ self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
@@ -258,10 +237,8 @@ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
- pref_end_time = pref_start_time + \
- 2.0 * self._get_communication_overhead('h2d', region.param_size)
- pref_ep = ExecutionPeriod(
- start_time=pref_start_time, end_time=pref_end_time)
+ pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
+ pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep
else:
@@ -276,18 +253,16 @@ def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
if is_fwd:
reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref
- flop_key = 'fwd_flop'
+ flop_key = "fwd_flop"
else:
reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref
- flop_key = 'bwd_flop'
- comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
- region.r_id, ExecutionPeriod(0, 0)).end_time)
- comp_end_time = comp_start_time + \
- sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
- for node in region.nodes])
- comp_ep = ExecutionPeriod(
- start_time=comp_start_time, end_time=comp_end_time)
+ flop_key = "bwd_flop"
+ comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
+ comp_end_time = comp_start_time + sum(
+ [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
+ )
+ comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep
@@ -297,10 +272,8 @@ def _insert_d2h_exec(self, region: Region):
"""
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
- offl_end_time = offl_start_time + \
- self._get_communication_overhead('d2h', region.param_size)
- offl_ep = ExecutionPeriod(
- start_time=offl_start_time, end_time=offl_end_time)
+ offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
+ offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep
@@ -332,20 +305,17 @@ def _eval_fwd_mem_per_region(self, region: Region):
self.fwd_reg_flow[region.r_id, region.r_id] = True
else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
- self.fwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear()
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size
- self.fwd_reg_flow[region.r_id,
- fwd_prefetch_region.r_id] = True
+ self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -354,8 +324,7 @@ def _eval_fwd_mem_per_region(self, region: Region):
if region.need_offload:
self.runtime_mem -= region.param_size
- assert len(
- self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+ assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region):
@@ -398,8 +367,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
- self.bwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
# free gradients in the buffer
while len(self.reg_buffer_to_free):
@@ -415,8 +383,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size
- self.bwd_reg_flow[region.r_id,
- bwd_prefetch_region.r_id] = True
+ self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True
# add the gradient of the parameter
if region.r_id < region.shared_rid:
@@ -426,10 +393,8 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
@@ -437,8 +402,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -446,12 +410,14 @@ def _eval_bwd_mem_per_region(self, region: Region):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!"
+ )
# release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
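
`AsynTrainingSimulator` models the schedule with `ExecutionPeriod`s on separate virtual engines (compute, H2D prefetch, D2H offload): a prefetch starts once the copy engine and the previous compute are done, and a region's compute waits for both the previous compute and its own prefetch. The core timing rule reduced to a toy (fixed per-region costs are an assumption):

```python
from dataclasses import dataclass

@dataclass
class ExecutionPeriod:
    start_time: float
    end_time: float

def simulate(comp_time, h2d_time):
    """Prefetch region r+1 while computing region r; region 0 is assumed resident."""
    last_comp = ExecutionPeriod(0.0, 0.0)
    last_h2d = ExecutionPeriod(0.0, 0.0)
    pref_done = {0: 0.0}
    for r in range(len(comp_time)):
        if r + 1 < len(comp_time):                       # issue next region's prefetch
            s = max(last_h2d.end_time, last_comp.end_time)
            last_h2d = ExecutionPeriod(s, s + h2d_time[r + 1])
            pref_done[r + 1] = last_h2d.end_time
        s = max(last_comp.end_time, pref_done[r])        # wait for our own params
        last_comp = ExecutionPeriod(s, s + comp_time[r])
    return last_comp.end_time

# Copies hide behind compute: 9.0 time units instead of 11 fully serialized.
print(simulate(comp_time=[3.0, 3.0, 3.0], h2d_time=[0.0, 1.0, 1.0]))
```
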
diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py
index 6b010512cc9c..cb65da79c5a2 100644
--- a/colossalai/auto_parallel/offload/util.py
+++ b/colossalai/auto_parallel/offload/util.py
@@ -35,7 +35,6 @@ class NvDevicePower:
class GlobalRuntimeInfo(metaclass=SingletonMeta):
-
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward
for region in region_list:
for node in region.nodes:
- runtime_mem = runtime_mem + \
- calculate_fwd_tmp(node) + calculate_fwd_out(node)
+ runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem)
# backward
bwd_deps = {}
for region in region_list.__reversed__():
for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node)
- runtime_mem = runtime_mem + \
- node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
+ runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
act_peak_mem = max(runtime_mem, act_peak_mem)
- runtime_mem = runtime_mem - \
- node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
+ runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)
# free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes)
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if user_node in bwd_deps:
bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0:
- runtime_mem -= user_node.meta['bwd_mem_out']
+ runtime_mem -= user_node.meta["bwd_mem_out"]
return act_peak_mem
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def requires_upload_p_in_fwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_release_p_in_bwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_offload_g_in_bwd(region: Region):
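
`compute_act_peak_mem` walks the graph twice (forward, then reversed) and frees a node's output only after its last consumer has run, tracked with a dependency counter. The counting scheme in miniature:

```python
def peak_activation_mem(nodes):
    """nodes: list of (name, out_bytes, input_names). Free a producer's output
    once every consumer has executed; report the running-sum peak."""
    remaining_users = {}
    for _, _, inputs in nodes:
        for i in inputs:
            remaining_users[i] = remaining_users.get(i, 0) + 1

    live = peak = 0
    out_bytes = {}
    for name, out, inputs in nodes:
        out_bytes[name] = out
        live += out
        peak = max(peak, live)
        for i in inputs:
            remaining_users[i] -= 1
            if remaining_users[i] == 0:      # last user ran: release the activation
                live -= out_bytes[i]
    return peak

print(peak_activation_mem([("a", 4, []), ("b", 8, ["a"]), ("c", 2, ["a", "b"])]))  # 14
```
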
diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
index ffda58e0689f..ba290ee839d8 100644
--- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py
+++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
@@ -14,18 +14,20 @@
shape_consistency_manager = ShapeConsistencyManager()
-def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
- target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
+def _construct_shard_meta_info(
+ node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
+) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
- origin_sharding_spec, target_sharding_spec)
+ origin_sharding_spec, target_sharding_spec
+ )
meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
- input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+ input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
# get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
- origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
- user_node_index]
+ origin_sharding_spec, target_sharding_spec = (
+ origin_spec_dict[node_index],
+ sharding_spec_dict[node_index][user_node_index],
+ )
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost
meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
- output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+ output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else:
# this case will be handled by shape consistency manager
- origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
- 'tgt_spec']
+ origin_sharding_spec, target_sharding_spec = (
+ comm_action.comm_spec["src_spec"],
+ comm_action.comm_spec["tgt_spec"],
+ )
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
- comm_actions_dict: Dict) -> GraphModule:
+def comm_metainfo_pass(
+ gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
+) -> GraphModule:
"""
This pass manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
- setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+ setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
- setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+ setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
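
One substantive detail hiding in the reformatting above: `TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)` passes the `MemoryCost` class itself as the third argument rather than an instance, which looks like a latent bug the formatter faithfully preserved. The surrounding logic is simple: shape-consistency costs are counted in element numbers, so the pass multiplies them by the element size of a neighboring tensor to get bytes:

```python
import torch

def numel_cost_to_bytes(cost_in_numel: float, ref: torch.Tensor) -> float:
    """Scale a cost counted in elements by the reference tensor's element size."""
    return cost_in_numel * ref.element_size()

ref = torch.empty(1, dtype=torch.float16)
print(numel_cost_to_bytes(1024, ref))  # 2048 — fp16 elements are 2 bytes each
```
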
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index 0673b767de7b..9b000549de6c 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
-
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
- 'placeholder': self.placeholder_handler,
- 'get_attr': self.get_attr_handler,
- 'output': self.output_handler,
- 'call_function': self.node_handler,
- 'call_module': self.node_handler,
- 'call_method': self.node_handler,
+ "placeholder": self.placeholder_handler,
+ "get_attr": self.get_attr_handler,
+ "output": self.output_handler,
+ "call_function": self.node_handler,
+ "call_module": self.node_handler,
+ "call_method": self.node_handler,
}
def _set_data_ptr(self, x):
@@ -46,7 +45,7 @@ def _is_inplace(self, node: Node):
"""
Check if the node is inplace operation.
"""
- if node.op == 'call_module':
+ if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@@ -66,7 +65,7 @@ def placeholder_handler(self, node: Node) -> None:
Handle the placeholder node.
"""
graph_info = GraphInfo()
- out = _normalize_tuple(getattr(node, '_meta_data', None))
+ out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@@ -96,7 +95,7 @@ def node_handler(self, node: Node) -> None:
"""
Handle other kinds of nodes.
"""
- assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
+ assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
@@ -126,7 +125,8 @@ def node_handler(self, node: Node) -> None:
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
- (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+ (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
+ )
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
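
`MetaInfoProp` dispatches on `node.op` through a plain dict of bound methods, which keeps each FX opcode's handling isolated. The pattern on its own (toy handlers):

```python
class OpDispatcher:
    """Map an FX-style opcode string to a bound handler method."""

    def __init__(self):
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "call_function": self.node_handler,
            "call_module": self.node_handler,   # several opcodes share one handler
            "output": self.output_handler,
        }

    def run(self, op: str, node: str) -> str:
        return self.func_dict[op](node)

    def placeholder_handler(self, node):
        return f"record input metadata for {node}"

    def node_handler(self, node):
        return f"propagate sharding metadata through {node}"

    def output_handler(self, node):
        return f"collect outputs at {node}"

print(OpDispatcher().run("call_function", "add_1"))
```
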
diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 2049a06187d2..27afe72c0db8 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -1,18 +1,10 @@
-from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationData,
- OperationDataType,
- TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
-def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
- user_node_index: int):
+def runtime_apply_for_iterable_object(
+ node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
+):
"""
This method is invoked at runtime to perform shape consistency, making sure activations of tuple or list type
are converted into the form expected by the user node.
"""
rst = []
- for index, (origin_sharding_spec,
- target_sharding_spec) in enumerate(zip(origin_dict[node_index],
- input_dict[node_index][user_node_index])):
+ for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
+ zip(origin_dict[node_index], input_dict[node_index][user_node_index])
+ ):
rst.append(
- shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
- target_sharding_spec))
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ node[index], origin_sharding_spec, target_sharding_spec
+ )
+ )
rst = type(node)(rst)
return rst
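
Structurally, `runtime_apply_for_iterable_object` is a container-preserving elementwise conversion. The hedged sketch below shows only that shape of the logic; `convert` is a stand-in for the real shape-consistency call:

```python
def convert(x, origin_spec, target_spec):
    # stand-in for shape_consistency_manager.apply_for_autoparallel_runtime
    return f"{x}:{origin_spec}->{target_spec}"

def apply_for_iterable(values, origin_specs, target_specs):
    rst = [convert(v, o, t) for v, (o, t) in zip(values, zip(origin_specs, target_specs))]
    # preserve the container type (tuple stays tuple, list stays list)
    return type(values)(rst)

print(apply_for_iterable(("a", "b"), ["R", "S0"], ["S0", "R"]))
# ('a:R->S0', 'b:S0->R')
```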
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
- origin_sharding_spec = comm_action.comm_spec['src_spec']
- tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+ origin_sharding_spec = comm_action.comm_spec["src_spec"]
+ tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if node.target == 'comm_actions_dict':
+ if node.target == "comm_actions_dict":
comm_actions_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
- node.target_sharding_specs,
- (list,
- tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+ node.target_sharding_specs, (list, tuple)
+ ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
total_difference = 0
- for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
- node.target_sharding_specs[user_node_index]):
+ for sharding_spec, target_sharding_spec in zip(
+ node.sharding_spec, node.target_sharding_specs[user_node_index]
+ ):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply_for_iterable_object,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply_for_iterable_object,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
else:
- assert isinstance(node.sharding_spec,
- ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+ assert isinstance(
+ node.sharding_spec, ShardingSpec
+ ), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
- if hasattr(user_node.meta['info'], 'activation_checkpoint'):
- MetaInfo(shape_consistency_node,
- mod_dir=user_node.meta['info'].mod_dir,
- activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
+ if hasattr(user_node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ shape_consistency_node,
+ mod_dir=user_node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
+ )
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
-
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(comm_object, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
# the origin node may be a positional argument or key word argument of user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(node, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
@@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(comm_spec_apply_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ comm_spec_apply_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
return gm
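
All of the passes above rely on the same torch.fx mechanics: open an `inserting_before`/`inserting_after` context, create a `call_function` node, and rewire the consumer's arguments. A minimal runnable sketch of that mechanic (the `scale` function is invented for illustration):

```python
import torch
from torch.fx import symbolic_trace

def scale(x):
    return x * 2.0

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = symbolic_trace(M())
graph = gm.graph

# Insert a `call_function` node in front of the relu node and reroute its
# input, the same way `mod_graph.create_node(...)` is used in the passes above.
for node in list(graph.nodes):
    if node.op == "call_function" and node.target == torch.relu:
        with graph.inserting_before(node):
            new_node = graph.create_node("call_function", scale, args=(node.args[0],))
        node.args = (new_node,)

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # tensor([2., 2.])
```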
@@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes)
for node in nodes:
- if not hasattr(node.meta, 'activation_checkpoint'):
- from .runtime_preparation_pass import size_processing
+ if not hasattr(node.meta, "activation_checkpoint"):
+ pass
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
- if 'activation_checkpoint' in user_node.meta:
- user_act_annotation = user_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in user_node.meta:
+ user_act_annotation = user_node.meta["activation_checkpoint"]
break
for input_node in node._input_nodes.keys():
- if 'activation_checkpoint' in input_node.meta:
- input_act_annotation = input_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in input_node.meta:
+ input_act_annotation = input_node.meta["activation_checkpoint"]
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
- node.meta['activation_checkpoint'] = user_act_annotation
+ node.meta["activation_checkpoint"] = user_act_annotation
return gm
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 0ed0742ee57e..65c3d8e0cbeb 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -1,19 +1,12 @@
import operator
-from copy import deepcopy
from typing import Dict, List, Union
import torch
-from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationDataType,
- ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@
shape_consistency_manager = ShapeConsistencyManager()
-def size_processing(size: Union[int, torch.Size],
- dim_partition_dict: Dict[int, List[int]],
- device_mesh_info: Dict[int, int],
- target_dim: int = None,
- node_name: str = None):
+def size_processing(
+ size: Union[int, torch.Size],
+ dim_partition_dict: Dict[int, List[int]],
+ device_mesh_info: Dict[int, int],
+ target_dim: int = None,
+ node_name: str = None,
+):
"""
    This method will be invoked during runtime to convert the size node value according to the distributed information.
"""
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size
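
For intuition, here is a hedged sketch of the idea behind `size_processing`: a sharded tensor reports its local size, so every dimension listed in `dim_partition_dict` has to be scaled by the sizes of the device-mesh dimensions it is sharded over. The helper name and shapes are illustrative:

```python
import torch

def global_size(local_size: torch.Size,
                dim_partition_dict: dict,
                device_mesh_info: dict) -> torch.Size:
    size = list(local_size)
    for dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            # a dimension sharded over a mesh dim of size k is k times larger globally
            size[dim] *= device_mesh_info[mesh_dim]
    return torch.Size(size)

# an (8, 16) tensor sharded along dim 0 over mesh dim 0 (size 2) reports (4, 16) locally
print(global_size(torch.Size([4, 16]), {0: [0]}, {0: 2, 1: 4}))  # torch.Size([8, 16])
```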
-def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
- strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(
+ gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
+):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
- str(node))
+ str(node)
+ )
# attach the corresponding metainfo if node has the attribute `strategies_info`
- if hasattr(node, 'strategies_info'):
- setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+ if hasattr(node, "strategies_info"):
+ setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
- setattr(node, 'target_sharding_specs', target_sharding_specs)
+ setattr(node, "target_sharding_specs", target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
- if node.op == 'get_attr':
- assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+ if node.op == "get_attr":
+            assert len(target_sharding_specs) == 1, "weight sharing is not supported in the current version."
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
- if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+ if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
- comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+ comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
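
As a toy illustration of how the solver solution is consumed here: the solution is one strategy index per node, mapped back onto each node's strategy list (the strategy names below are hypothetical):

```python
# one candidate-strategy list per node, and one chosen index per node
strategies = {"conv1": ["S0R", "RS1", "RR"], "relu": ["RR", "S0R"]}
solution = [2, 1]  # as returned by the solver, one entry per node

best = {name: options[idx] for (name, options), idx in zip(strategies.items(), solution)}
print(best)  # {'conv1': 'RR', 'relu': 'S0R'}
```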
@@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
- '''
+ """
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
@@ -156,7 +153,7 @@ def _extract_target_dim(node):
    If a target_dim is assigned, the output will be of type int instead of torch.Size.
    Otherwise, the output will be of type torch.Size and this function will return None.
- '''
+ """
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
@@ -165,19 +162,21 @@ def _extract_target_dim(node):
return target_dim
def _post_processing(node, size_processing_node):
- '''
+ """
This function is used to process the dependency between the size node and its users after
    inserting the size_processing_node.
- '''
+ """
# store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(size_processing_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ size_processing_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
user_list = list(node.users.keys())
for user in user_list:
@@ -196,10 +195,10 @@ def _post_processing(node, size_processing_node):
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
- '''
+ """
This function is used to update the slice object argument list.
    If the slice object contains a Node argument, then the size node will be replaced with its corresponding size-processing node.
- '''
+ """
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
@@ -220,8 +219,7 @@ def _update_slice_object_args(slice_object):
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
-
- if node.op == 'call_method' and node.target == 'size':
+ if node.op == "call_method" and node.target == "size":
# extract useful information from size node
# dim_partition_dict will instruct the size value on which
# dimension should be enlarged.
@@ -232,14 +230,14 @@ def _update_slice_object_args(slice_object):
# insert size_processing node
with mod_graph.inserting_after(node):
- size_processing_node = mod_graph.create_node('call_function',
- size_processing,
- args=(node, dim_partition_dict, device_mesh_info,
- target_dim, node.name))
+ size_processing_node = mod_graph.create_node(
+ "call_function",
+ size_processing,
+ args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+ )
_post_processing(node, size_processing_node)
- if node.op == 'call_function' and node.target == operator.getitem:
-
+ if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1]
# slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
- '''
+ """
This function is used to extract the dim_partition_dict and device_mesh from
    a sharding spec instance or a list of sharding specs.
- '''
+ """
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
- assert isinstance(sharding_spec,
- (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+ assert isinstance(
+ sharding_spec, (tuple, list)
+ ), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
@@ -322,8 +321,9 @@ def _process_node_arguments(node):
else:
new_args.append(arg)
else:
- assert isinstance(arg,
- (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+ assert isinstance(
+ arg, (int, tuple, list)
+                ), "The argument in a view node should be either a Node or an int."
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
@@ -332,7 +332,7 @@ def _process_node_arguments(node):
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
- if node.op == 'call_method':
+ if node.op == "call_method":
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
args_to_process = tuple(args_to_process)
- if node.op == 'call_method':
+ if node.op == "call_method":
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
@@ -358,9 +358,9 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
node.args = new_args
def _filter_node_with_shape_args(node):
- if node.op == 'call_method':
+ if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target)
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
else:
target = None
@@ -371,7 +371,7 @@ def _filter_node_with_shape_args(node):
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
- if not hasattr(node, 'sharding_spec'):
+ if not hasattr(node, "sharding_spec"):
continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None):
-
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name):
-
- if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "call_module"
+ and op_data.type == OperationDataType.PARAM
+ and op_data.name == name
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
- if node.op == 'get_attr' and isinstance(
- node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "get_attr"
+ and isinstance(node._meta_data, torch.nn.parameter.Parameter)
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
return False
@@ -410,7 +416,6 @@ def _filter_param_to_hook(node, op_data, comm_action, name):
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
-
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
@@ -426,22 +431,26 @@ def _shard_param(param, target_sharding_spec):
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
- setattr(param, 'sharding_spec', origin_sharding_spec)
+ setattr(param, "sharding_spec", origin_sharding_spec)
        # TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
- shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
- target_sharding_spec).detach().clone())
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ param.data, param.sharding_spec, target_sharding_spec
+ )
+ .detach()
+ .clone()
+ )
return param
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters.
- if hasattr(target_module, 'processed') and target_module.processed:
+ if hasattr(target_module, "processed") and target_module.processed:
continue
- setattr(target_module, 'processed', True)
+ setattr(target_module, "processed", True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def _shard_param(param, target_sharding_spec):
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
- setattr(buffer, 'sharding_spec', origin_sharding_spec)
+ setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def _shard_param(param, target_sharding_spec):
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
- if node.op == 'get_attr':
+ if node.op == "get_attr":
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
    Replace the original kernel with a kernel that performs implicit communication internally.
"""
- pass
-def runtime_preparation_pass(gm: torch.fx.GraphModule,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap=False):
+def runtime_preparation_pass(
+ gm: torch.fx.GraphModule,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap=False,
+):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
- gm, solution, strategies_constructor)
+ gm, solution, strategies_constructor
+ )
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py
index 99c124934060..e9c2c8664a61 100644
--- a/colossalai/auto_parallel/tensor_shard/constants.py
+++ b/colossalai/auto_parallel/tensor_shard/constants.py
@@ -3,9 +3,22 @@
import torch
__all__ = [
- 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
- 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
- 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
+ "ELEMENTWISE_MODULE_OP",
+ "ELEMENTWISE_FUNC_OP",
+ "RESHAPE_FUNC_OP",
+ "CONV_MODULE_OP",
+ "CONV_FUNC_OP",
+ "LINEAR_MODULE_OP",
+ "LINEAR_FUNC_OP",
+ "BATCHNORM_MODULE_OP",
+ "POOL_MODULE_OP",
+ "NON_PARAM_FUNC_OP",
+ "BCAST_FUNC_OP",
+ "EMBEDDING_MODULE_OP",
+ "LAYERNORM_MODULE_OP",
+ "ELEMENTWISE_METHOD_OP",
+ "RESHAPE_METHOD_OP",
+ "INFINITY_COST",
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -18,13 +31,13 @@
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
    # TODO: contiguous may need some extra processing.
- torch.Tensor.contiguous
+ torch.Tensor.contiguous,
]
RESHAPE_FUNC_OP = [
torch.flatten,
@@ -42,15 +55,36 @@
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
- torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
- operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
+ torch.add,
+ torch.sub,
+ torch.mul,
+ torch.div,
+ torch.floor_divide,
+ torch.true_divide,
+ operator.add,
+ operator.sub,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ torch.matmul,
+ operator.pow,
+ torch.pow,
]
CONV_MODULE_OP = [
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
- torch.nn.ConvTranspose3d
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
]
CONV_FUNC_OP = [
- torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
+ torch.conv1d,
+ torch.conv2d,
+ torch.conv3d,
+ torch.conv_transpose1d,
+ torch.conv_transpose2d,
+ torch.conv_transpose3d,
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
@@ -85,7 +119,7 @@
operator.floordiv,
operator.truediv,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
INFINITY_COST = 1e13
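
These constant lists act as membership tables that route an FX node's `target` to the right handler. A small sketch of that use, with trimmed-down subsets of the lists above:

```python
import operator
import torch

BCAST_FUNC_OP = [torch.add, torch.mul, operator.add, operator.mul]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]

def classify(target):
    # route an FX node target by membership in the op tables
    if target in BCAST_FUNC_OP:
        return "broadcast-elementwise"
    if target in RESHAPE_FUNC_OP:
        return "reshape"
    return "other"

print(classify(operator.add))   # broadcast-elementwise
print(classify(torch.reshape))  # reshape
```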
diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index b406ca6fb7e0..d82f0ef53f66 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -3,7 +3,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
@@ -14,27 +13,32 @@
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
- '''
+ """
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
- '''
-
- def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
- origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
- '''
+ """
+
+ def __init__(
+ self,
+ module: ColoGraphModule,
+ sharding_spec_dict: Dict[int, List[ShardingSpec]],
+ origin_spec_dict: Dict[int, ShardingSpec],
+ comm_actions_dict: Dict[int, Dict[str, CommAction]],
+ ):
+ """
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
- '''
+ """
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
@@ -42,67 +46,68 @@ def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[S
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
- return self.module(*args,
- sharding_spec_convert_dict=self.sharding_spec_dict,
- origin_node_sharding_spec_dict=self.origin_spec_dict,
- comm_actions_dict=self.comm_actions_dict,
- **kwargs)
+ return self.module(
+ *args,
+ sharding_spec_convert_dict=self.sharding_spec_dict,
+ origin_node_sharding_spec_dict=self.origin_spec_dict,
+ comm_actions_dict=self.comm_actions_dict,
+ **kwargs,
+ )
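
A minimal sketch of the `ModuleWrapper` idea using nothing beyond plain PyTorch: a thin `nn.Module` that forwards user arguments and splices precomputed runtime dictionaries in as keyword arguments (`Wrapper`, `Inner`, and `runtime_info` are invented names):

```python
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, module, extra):
        super().__init__()
        self.module = module
        self.extra = extra

    def forward(self, *args, **kwargs):
        # splice the precomputed dictionary into every forward call
        return self.module(*args, runtime_info=self.extra, **kwargs)

class Inner(nn.Module):
    def forward(self, x, runtime_info=None):
        return x + runtime_info["offset"]

print(Wrapper(Inner(), {"offset": 1.0})(torch.zeros(2)))  # tensor([1., 1.])
```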
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
- '''
+ """
    This method is used to extract the meta_args from the dataloader, as directed by the data_process_func.
- '''
+ """
# TODO: implement this function
- pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
- '''
+ """
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
- '''
+ """
# TODO: implement this function
- pass
-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
- shard_option: str):
- '''
+def build_strategy_constructor(
+ graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
+):
+ """
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
- '''
- if solver_preference == 'standard':
+ """
+ if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD
- elif solver_preference == 'tp':
+ elif solver_preference == "tp":
solver_preference = SolverPerference.TP
- elif solver_preference == 'dp':
+ elif solver_preference == "dp":
solver_preference = SolverPerference.DP
else:
- raise ValueError(f'Invalid solver_preference: {solver_preference}')
+ raise ValueError(f"Invalid solver_preference: {solver_preference}")
- if dataloader_option == 'replicated':
+ if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED
- elif dataloader_option == 'distributed':
+ elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED
else:
- raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+ raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
- if shard_option == 'standard':
+ if shard_option == "standard":
shard_option = ShardOption.STANDARD
- elif shard_option == 'shard':
+ elif shard_option == "shard":
shard_option = ShardOption.SHARD
- elif shard_option == 'shard_last_axis':
+ elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS
- elif shard_option == 'full_shard':
+ elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD
else:
- raise ValueError(f'Invalid shard_option: {shard_option}')
+ raise ValueError(f"Invalid shard_option: {shard_option}")
- solver_options = SolverOptions(solver_perference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ solver_options = SolverOptions(
+ solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
+ )
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
- '''
+ """
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
- '''
+ """
    # Temporarily, we use all nodes as the liveness list; the backward memory cost is counted together with
    # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
# graph_analyser = GraphAnalyser(gm)
@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
-def transform_to_sharded_model(gm: ColoGraphModule,
- meta_args: Dict,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap: bool = False):
- '''
+def transform_to_sharded_model(
+ gm: ColoGraphModule,
+ meta_args: Dict,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap: bool = False,
+):
+ """
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
- '''
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
- solution,
- device_mesh,
- strategies_constructor,
- overlap=overlap)
+ """
+ gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+ gm, solution, device_mesh, strategies_constructor, overlap=overlap
+ )
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts
-def initialize_device_mesh(world_size: int = -1,
- physical_devices: List[int] = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None):
- '''
+def initialize_device_mesh(
+ world_size: int = -1,
+ physical_devices: List[int] = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+):
+ """
This method is used to initialize the device mesh.
Args:
@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
- '''
+ """
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
- device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
- logical_mesh_id=logical_mesh_id,
- mesh_alpha=mesh_alpha,
- mesh_beta=mesh_beta,
- init_process_group=True)
+ device_mesh = DeviceMesh(
+ physical_mesh_id=physical_mesh,
+ logical_mesh_id=logical_mesh_id,
+ mesh_alpha=mesh_alpha,
+ mesh_beta=mesh_beta,
+ init_process_group=True,
+ )
return device_mesh
-def initialize_model(model: nn.Module,
- meta_args: Dict[str, torch.Tensor],
- device_mesh: DeviceMesh,
- memory_budget: float = -1.0,
- overlap: bool = False,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solution_path: str = None,
- return_solution: bool = False):
- '''
+def initialize_model(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor],
+ device_mesh: DeviceMesh,
+ memory_budget: float = -1.0,
+ overlap: bool = False,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solution_path: str = None,
+ return_solution: bool = False,
+):
+ """
    This method is used to initialize the sharded model, which can be used as a normal PyTorch model.
Args:
@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
- '''
+ """
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
- strategies_constructor = build_strategy_constructor(graph,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ strategies_constructor = build_strategy_constructor(
+ graph,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ )
if load_solver_solution:
solution = torch.load(solution_path)
else:
@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
- gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
- overlap)
+ gm, sharding_spec_dicts = transform_to_sharded_model(
+ gm, meta_args, solution, device_mesh, strategies_constructor, overlap
+ )
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
- solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
+ solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return
else:
return model_to_return
-def autoparallelize(model: nn.Module,
- meta_args: Dict[str, torch.Tensor] = None,
- data_loader: torch.utils.data.DataLoader = None,
- data_process_func: callable = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solver_solution_path: str = None,
- return_solution: bool = False,
- memory_budget: float = -1.0):
- '''
+def autoparallelize(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor] = None,
+ data_loader: torch.utils.data.DataLoader = None,
+ data_process_func: callable = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solver_solution_path: str = None,
+ return_solution: bool = False,
+ memory_budget: float = -1.0,
+):
+ """
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
- '''
- device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
- logical_mesh_shape=logical_mesh_shape,
- logical_mesh_id=logical_mesh_id)
+ """
+ device_mesh = initialize_device_mesh(
+ alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
+ )
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
- rst_to_unpack = initialize_model(model,
- meta_args,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option,
- save_solver_solution=save_solver_solution,
- load_solver_solution=load_solver_solution,
- solution_path=solver_solution_path,
- return_solution=return_solution,
- memory_budget=memory_budget)
+ rst_to_unpack = initialize_model(
+ model,
+ meta_args,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ save_solver_solution=save_solver_solution,
+ load_solver_solution=load_solver_solution,
+ solution_path=solver_solution_path,
+ return_solution=return_solution,
+ memory_budget=memory_budget,
+ )
if return_solution:
model, solution = rst_to_unpack
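
A hedged usage sketch of the entry point above, assuming a multi-GPU job whose process group is already initialized (e.g. via `colossalai.launch_from_torch`); the model, shapes, and meta-argument name are illustrative:

```python
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# meta tensors let the solver see shapes without allocating device memory;
# the key is assumed to match the forward argument name
meta_args = {"x": torch.rand(64, 1024, device="meta")}
model = autoparallelize(MLP(), meta_args=meta_args)
```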
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 9903ca54e52c..aa2e5e9c40c0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -25,11 +25,33 @@
from .where_handler import WhereHandler
__all__ = [
- 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
- 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
- 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
- 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
- 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
- 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
- 'SplitHandler'
+ "LinearFunctionHandler",
+ "LinearModuleHandler",
+ "BMMFunctionHandler",
+ "AddBMMFunctionHandler",
+ "LayerNormModuleHandler",
+ "BatchNormModuleHandler",
+ "ConvModuleHandler",
+ "ConvFunctionHandler",
+ "UnaryElementwiseHandler",
+ "DefaultReshapeHandler",
+ "PlaceholderHandler",
+ "OutputHandler",
+ "WhereHandler",
+ "NormPoolingHandler",
+ "BinaryElementwiseHandler",
+ "MatMulHandler",
+ "operator_registry",
+ "ADDMMFunctionHandler",
+ "GetItemHandler",
+ "GetattrHandler",
+ "ViewHandler",
+ "PermuteHandler",
+ "TensorConstructorHandler",
+ "EmbeddingModuleHandler",
+ "EmbeddingFunctionHandler",
+ "SumHandler",
+ "SoftmaxHandler",
+ "TransposeHandler",
+ "SplitHandler",
]
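
The handlers exported above are attached to operators through `operator_registry.register`, used as a class decorator in the files that follow. A self-contained sketch of that registry pattern (this `Registry` class is a simplification, not the library's implementation):

```python
import torch

class Registry:
    def __init__(self):
        self.store = {}

    def register(self, source):
        # accept a single op or a list of ops, as the decorator usages below do
        def wrapper(cls):
            targets = source if isinstance(source, list) else [source]
            for target in targets:
                self.store[target] = cls
            return cls
        return wrapper

    def get(self, source):
        return self.store[source]

operator_registry = Registry()

@operator_registry.register(torch.addmm)
class ADDMMFunctionHandler:
    pass

print(operator_registry.get(torch.addmm).__name__)  # ADDMMFunctionHandler
```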
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
index da0d199c5e05..47c654d6aa43 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
@@ -2,15 +2,13 @@
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['ADDMMFunctionHandler']
+__all__ = ["ADDMMFunctionHandler"]
@operator_registry.register(torch.addmm)
@@ -30,25 +28,26 @@ def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-
# input operand
input_data = self.node.args[1]._meta_data
- physical_input_operand = OperationData(name=str(self.node.args[1]),
- type=self._infer_op_data_type(input_data),
- data=input_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data
+ )
# other operand
other_data = self.node.args[2]._meta_data
- physical_other_operand = OperationData(name=str(self.node.args[2]),
- type=self._infer_op_data_type(other_data),
- data=other_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data
+ )
# bias physical shape
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
- physical_bias_operand = OperationData(name=str(self.node.args[0]),
- type=self._infer_op_data_type(bias_data),
- data=bias_data,
- logical_shape=bias_logical_shape)
+ physical_bias_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=self._infer_op_data_type(bias_data),
+ data=bias_data,
+ logical_shape=bias_logical_shape,
+ )
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
@@ -57,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
- 'bias': physical_bias_operand
+ "bias": physical_bias_operand,
}
return mapping
@@ -66,26 +65,27 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm")
+ )
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- bias_op_data = op_data_mapping['bias']
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
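
For intuition about the bias post-processing above, here is a hedged, simplified sketch of the idea behind `recover_sharding_spec_for_broadcast_shape`, operating on plain dicts rather than real `ShardingSpec` objects: dimensions that exist only in the broadcast (logical) shape cannot stay sharded on the physical operand, so their mesh dimensions are reported as removed:

```python
def recover_for_broadcast(dim_partition, logical_shape, physical_shape):
    offset = len(logical_shape) - len(physical_shape)  # leading broadcast dims
    recovered, removed = {}, []
    for dim, mesh_dims in dim_partition.items():
        phys_dim = dim - offset
        if phys_dim < 0 or physical_shape[phys_dim] != logical_shape[dim]:
            removed.extend(mesh_dims)      # sharding lost on this operand
        else:
            recovered[phys_dim] = mesh_dims
    return recovered, removed

# a bias of shape (6,) broadcast against a logical output of shape (4, 6)
print(recover_for_broadcast({0: [0], 1: [1]}, (4, 6), (6,)))
# ({0: [1]}, [0])
```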
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
index cb1bb36b7879..df4b1d6cef3f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
@@ -2,12 +2,12 @@
import torch
-from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
-__all__ = ['BatchNormModuleHandler']
+__all__ = ["BatchNormModuleHandler"]
@operator_registry.register(torch.nn.BatchNorm1d)
@@ -27,30 +27,37 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
    # the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
- physical_running_mean_operand = OperationData(name="running_mean",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_mean'],
- logical_shape=self.named_buffers['running_mean'].shape)
+ physical_running_mean_operand = OperationData(
+ name="running_mean",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_mean"],
+ logical_shape=self.named_buffers["running_mean"].shape,
+ )
- physical_running_var_operand = OperationData(name="running_var",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_var'],
- logical_shape=self.named_buffers['running_var'].shape)
+ physical_running_var_operand = OperationData(
+ name="running_var",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_var"],
+ logical_shape=self.named_buffers["running_var"].shape,
+ )
physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked",
type=OperationDataType.BUFFER,
- data=self.named_buffers['num_batches_tracked'],
- logical_shape=self.named_buffers['num_batches_tracked'].shape)
+ data=self.named_buffers["num_batches_tracked"],
+ logical_shape=self.named_buffers["num_batches_tracked"].shape,
+ )
mapping = {
"input": physical_input_operand,
@@ -58,12 +65,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"output": physical_output,
"running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand,
- "num_batches_tracked": physical_num_batches_tracked_operand
+ "num_batches_tracked": physical_num_batches_tracked_operand,
}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index db8f0b54ddee..f8c137348353 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -4,15 +4,14 @@
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
-__all__ = ['BinaryElementwiseHandler']
+__all__ = ["BinaryElementwiseHandler"]
@operator_registry.register(BCAST_FUNC_OP)
@@ -38,7 +37,7 @@ def _get_arg_value(idx):
        # The meta_data of a Node-type argument could also be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
- meta_data = torch.Tensor([meta_data]).to('meta')
+ meta_data = torch.Tensor([meta_data]).to("meta")
non_tensor = True
else:
@@ -46,7 +45,7 @@ def _get_arg_value(idx):
# but we can deem it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
- meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
+ meta_data = torch.Tensor([self.node.args[idx]]).to("meta")
non_tensor = True
return meta_data, non_tensor
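
A small runnable sketch of the scalar-promotion trick used above: non-tensor arguments are wrapped as one-element meta tensors so that strategy generation can treat every operand uniformly (the helper name is invented):

```python
import torch

def as_meta_operand(value):
    if isinstance(value, torch.Tensor):
        return value.to("meta"), False
    # non-tensor scalars are promoted to 1-element meta tensors
    assert isinstance(value, (int, float))
    return torch.tensor([value], device="meta"), True

print(as_meta_operand(3.0))            # (tensor(..., device='meta', size=(1,)), True)
print(as_meta_operand(torch.ones(2)))  # (tensor(..., device='meta', size=(2,)), False)
```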
@@ -58,24 +57,27 @@ def _get_arg_value(idx):
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
# assert False
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=_get_op_data_type(input_meta_data),
- data=input_meta_data,
- logical_shape=bcast_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=_get_op_data_type(other_meta_data),
- data=other_meta_data,
- logical_shape=bcast_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=bcast_shape)
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=_get_op_data_type(input_meta_data),
+ data=input_meta_data,
+ logical_shape=bcast_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=_get_op_data_type(other_meta_data),
+ data=other_meta_data,
+ logical_shape=bcast_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape
+ )
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
@@ -100,14 +102,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- sharding_spec, logical_shape, physical_shape)
+ sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=op_data,
- sharding_spec=sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec
+ )
strategy.communication_actions[op_data] = comm_action
return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
index da2b733c9f7a..5c22ac7bef11 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
@@ -2,15 +2,13 @@
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
-__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
+__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"]
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
node handler to reduce code redundancy.
"""
# input operand
- physical_input_operand = OperationData(name=str(node.args[input_idx]),
- type=OperationDataType.ARG,
- data=node.args[input_idx]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data
+ )
# other operand
- physical_other_operand = OperationData(name=str(node.args[other_idx]),
- type=OperationDataType.ARG,
- data=node.args[other_idx]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data
+ )
# output
physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
if bias_idx is not None:
# bias physical shape
bias_logical_shape = node._meta_data.shape
- physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
- type=OperationDataType.ARG,
- data=node.args[bias_idx]._meta_data,
- logical_shape=bias_logical_shape)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(node.args[bias_idx]),
+ type=OperationDataType.ARG,
+ data=node.args[bias_idx]._meta_data,
+ logical_shape=bias_logical_shape,
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
@@ -91,20 +91,20 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- if 'bias' in op_data_mapping:
- bias_op_data = op_data_mapping['bias']
+ if "bias" in op_data_mapping:
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
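
`_get_data_mapping_for_bmm_op` pins the optional bias's logical shape to the node's output shape because `addbmm`/`baddbmm` broadcast the bias against the batched-matmul result, while its physical shape can be much smaller. A quick plain-PyTorch check of that behaviour (shapes chosen arbitrarily for illustration):

```python
# The bias operand of baddbmm is broadcast to the matmul output, so strategy
# generation uses the output shape as its logical shape even though the
# physical bias here is only 1-D.
import torch

batch1 = torch.randn(4, 3, 5)  # (b, n, m)
batch2 = torch.randn(4, 5, 7)  # (b, m, p)
bias = torch.randn(7)          # physical shape (p,)

out = torch.baddbmm(bias, batch1, batch2)  # bias broadcast to (b, n, p)
print(bias.shape, "->", out.shape)         # torch.Size([7]) -> torch.Size([4, 3, 7])
```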
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
index 272b1c85630a..fd7c1f837a5a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
@@ -3,13 +3,13 @@
import torch
import torch.nn.functional as F
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import transpose_partition_dim
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
-__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
+__all__ = ["ConvModuleHandler", "ConvFunctionHandler"]
@operator_registry.register(torch.nn.Conv1d)
@@ -29,25 +29,29 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
logical_shape_for_weight = list(self.named_parameters["weight"].shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.named_parameters:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
@@ -77,9 +81,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -88,26 +92,30 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
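
Both convolution handlers advertise the weight with its first two dimensions swapped, so the strategy generator reasons about an `(in_channels, out_channels, *kernel)` layout while the physical parameter keeps PyTorch's `(out_channels, in_channels, *kernel)` order; `post_process` transposes the partition back. A small sanity check of the shapes involved:

```python
# nn.Conv2d stores weight as (out_channels, in_channels, kH, kW); the handler
# swaps the first two dims to present a projection-like logical layout.
import torch

conv = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
physical = list(conv.weight.shape)               # [32, 16, 3, 3]
logical = physical.copy()
logical[0], logical[1] = logical[1], logical[0]  # swap in/out channels
print(physical, "->", torch.Size(logical))       # [32, 16, 3, 3] -> torch.Size([16, 32, 3, 3])
```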
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
index 0c5b9f39e1fb..feb1032a6c0f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator
-__all__ = ['DefaultReshapeHandler']
+__all__ = ["DefaultReshapeHandler"]
@operator_registry.register(torch.flatten)
@@ -54,17 +54,15 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=data_type,
- data=input_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape
+ )
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape
+ )
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
index 112ee194b4ec..f29c3a0b7d5d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
@@ -12,11 +12,12 @@
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
-__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
+__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"]
-def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={0: i},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping={0: i},
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
- strategy_copy.name = f'{strategy.name}_{i}'
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+                f"Error occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
@@ -125,14 +131,16 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'])
+ physical_other_operand = OperationData(
+ name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"]
+ )
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -141,10 +149,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
@@ -157,10 +167,9 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -183,10 +192,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -194,9 +205,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data
+ )
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -223,8 +234,7 @@ def post_process(self, strategy: ShardingStrategy):
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
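
The embedding handlers collapse all leading dimensions into one logical batch dimension: the input is viewed as 1-D and the output as `(batch, embedding)`, and `post_process` later re-attaches a shard on logical dim 0 to one of the physical leading dims. The logical shapes, reproduced with plain PyTorch:

```python
# Logical shapes used by the embedding handler: input (B, S) flattens to
# (B*S,) and output (B, S, H) flattens to (B*S, H).
import torch

emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)
tokens = torch.randint(0, 100, (4, 6))    # physical input (B, S)
out = emb(tokens)                         # physical output (B, S, H)

print(tokens.view(-1).shape)              # torch.Size([24])
print(out.view(-1, out.shape[-1]).shape)  # torch.Size([24, 8])
```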
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
index 53addb873d1d..dcf0a1760a2c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
@@ -4,7 +4,7 @@
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
-__all__ = ['GetattrHandler']
+__all__ = ["GetattrHandler"]
class GetattrHandler(NodeHandler):
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
index 3466e9dd9940..bd342c12eda9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
@@ -8,7 +8,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
-__all__ = ['GetItemHandler']
+__all__ = ["GetItemHandler"]
@operator_registry.register(operator.getitem)
@@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
index 452381169b74..ce6b20fa1d24 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
-__all__ = ['LayerNormModuleHandler']
+__all__ = ["LayerNormModuleHandler"]
@operator_registry.register(torch.nn.LayerNorm)
@@ -25,20 +25,22 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
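
Module-based handlers like this one resolve their "weight"/"bias" operands from `self.named_parameters`, which `ModuleHandler.__init__` builds from `module.named_parameters(recurse=False)`. For `nn.LayerNorm` that yields exactly the pair the mapping above expects:

```python
# The directly-owned parameters of LayerNorm, as a name -> shape mapping;
# this is the dictionary shape the module handler's lookups assume.
import torch

ln = torch.nn.LayerNorm(8)
params = {name: tuple(p.shape) for name, p in ln.named_parameters(recurse=False)}
print(params)  # {'weight': (8,), 'bias': (8,)}
```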
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index ea541e434009..4177af4eaf71 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -3,24 +3,21 @@
import torch
import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_sharding_spec_validity,
- transpose_partition_dim,
- update_partition_dim,
-)
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
+__all__ = ["LinearModuleHandler", "LinearFunctionHandler"]
-def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
- weight_name: str) -> ShardingStrategy:
+def _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy: ShardingStrategy, weight_name: str
+) -> ShardingStrategy:
"""
This function is a helper function used by both module node handler and function node handler. This function will
convert the sharding spec for the transposed weight to the correct partition spec.
@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
- assert op_data.logical_shape[0] == op_data.data.shape[1] and \
- op_data.logical_shape[1] == op_data.data.shape[0], \
- "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
+ assert (
+ op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0]
+ ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
-def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
- strategy_copy.name = f'{strategy.name}_{i}'
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+                f"Error occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
@@ -152,10 +158,13 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping,
- self.device_mesh,
- linear_projection_type='linear',
- solver_perference=self.solver_perference))
+ LinearProjectionStrategyGenerator(
+ op_data_mapping,
+ self.device_mesh,
+ linear_projection_type="linear",
+ solver_perference=self.solver_perference,
+ )
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -163,28 +172,34 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape[::-1])
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.named_parameters is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+        if "bias" in self.named_parameters:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -194,14 +209,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight")
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -215,7 +230,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -223,10 +239,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -234,10 +252,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=self.node.args[1]._meta_data.shape[::-1])
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=self.node.args[1]._meta_data.shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
@@ -249,27 +269,28 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
- weight_name=str(self.node.args[1]))
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy=strategy, weight_name=str(self.node.args[1])
+ )
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
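
The linear handlers lean on two facts visible in the mapping code above: `nn.Linear` stores its weight as `(out_features, in_features)`, so the logical shape is the reversed physical shape, and n-D activations are viewed as 2-D `(tokens, features)` matrices for strategy generation. A quick shape check:

```python
# Linear weight is (out, in); reversing it gives the mathematical (in, out)
# logical layout, and activations are flattened to 2-D for the strategies.
import torch

linear = torch.nn.Linear(in_features=16, out_features=32)
x = torch.randn(4, 6, 16)             # physical input (B, S, in)
y = linear(x)                         # physical output (B, S, out)

print(linear.weight.shape[::-1])      # torch.Size([16, 32])
print(x.view(-1, x.shape[-1]).shape)  # torch.Size([24, 16])
print(y.view(-1, y.shape[-1]).shape)  # torch.Size([24, 32])
```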
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index fa51114a5c94..4fab5f7f05eb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -16,7 +16,7 @@
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
@@ -37,6 +37,7 @@ class MatMulType(Enum):
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
+
DOT = 0
MM = 1
MV = 2
@@ -92,26 +93,26 @@ def __init__(self) -> None:
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
- self.padded_dim_mapping['input'] = -2
- self.padded_dim_mapping['output'] = -2
+ self.padded_dim_mapping["input"] = -2
+ self.padded_dim_mapping["output"] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape = other_shape.append(1)
- self.padded_dim_mapping['other'] = -1
- self.padded_dim_mapping['output'] = -1
+ self.padded_dim_mapping["other"] = -1
+ self.padded_dim_mapping["output"] = -1
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
- input_op_data = op_data_mapping['input']
- other_op_data = op_data_mapping['other']
+ op_data_mapping["input"]
+ op_data_mapping["other"]
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
@@ -131,7 +132,7 @@ def _remove_padded_dim(key, strategy):
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
- assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
+ assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}"
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
@@ -142,15 +143,15 @@ def _remove_padded_dim(key, strategy):
strategy_copy = strategy.clone()
# only one of input and other will be padded
- if 'input' in self.padded_dim_mapping:
- _remove_padded_dim('input', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
- elif 'other' in self.padded_dim_mapping:
- _remove_padded_dim('other', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
+ if "input" in self.padded_dim_mapping:
+ _remove_padded_dim("input", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
+ elif "other" in self.padded_dim_mapping:
+ _remove_padded_dim("other", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
@@ -167,8 +168,8 @@ def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
@@ -179,16 +180,16 @@ def apply(self, shape_mapping: Dict[str, List[int]]):
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
- self.broadcast_dim_info['input'] = input_broadcast_dim_info
- self.broadcast_dim_info['other'] = other_broadcast_dim_info
+ self.broadcast_dim_info["input"] = input_broadcast_dim_info
+ self.broadcast_dim_info["other"] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
return mapping_copy
@@ -216,17 +217,18 @@ def _remove_sharding_on_broadcast_dim(key, strategy):
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
- physical_shape=tensor_shape_before_broadcast)
+ physical_shape=tensor_shape_before_broadcast,
+ )
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
- _remove_sharding_on_broadcast_dim('input', strategy_copy)
- _remove_sharding_on_broadcast_dim('other', strategy_copy)
+ _remove_sharding_on_broadcast_dim("input", strategy_copy)
+ _remove_sharding_on_broadcast_dim("other", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
@@ -241,20 +243,20 @@ def __init__(self) -> None:
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
- self.batch_dims_before_view = list(mapping_copy['input'][:-2])
+ self.batch_dims_before_view = list(mapping_copy["input"][:-2])
# get shapes
- input_shape = shape_mapping['input']
- other_shape = shape_mapping['other']
+ input_shape = shape_mapping["input"]
+ other_shape = shape_mapping["other"]
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
- mapping_copy['output'] = output_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
+ mapping_copy["output"] = output_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
@@ -291,11 +293,11 @@ def _update_sharding_spec(key, strategy, physical_batch_dim):
# create a new strategy
strategy_copy = strategy.clone()
try:
- _update_sharding_spec('input', strategy_copy, i)
- _update_sharding_spec('other', strategy_copy, i)
- _update_sharding_spec('output', strategy_copy, i)
+ _update_sharding_spec("input", strategy_copy, i)
+ _update_sharding_spec("other", strategy_copy, i)
+ _update_sharding_spec("output", strategy_copy, i)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
continue
return strategies
@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
3. reshape to 3 dimensions
"""
- shape_mapping = {'input': input_shape, 'other': other_shape}
+ shape_mapping = {"input": input_shape, "other": other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
- input_shape = shape_mapping.get('input', None)
- other_shape = shape_mapping.get('other', None)
- output_shape = shape_mapping.get('output', None)
+ input_shape = shape_mapping.get("input", None)
+ other_shape = shape_mapping.get("other", None)
+ output_shape = shape_mapping.get("output", None)
return input_shape, other_shape, output_shape
@@ -364,7 +366,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -372,7 +375,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
- MatMulType.BMM: self._get_logical_shape_for_bmm
+ MatMulType.BMM: self._get_logical_shape_for_bmm,
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
@@ -390,20 +393,26 @@ def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_
output_logical_shape = torch.Size(output_logical_shape)
# create op data
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.input_meta_data,
- logical_shape=input_logical_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.other_meta_data,
- logical_shape=other_logical_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=self.output_meta_data,
- logical_shape=output_logical_shape)
-
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=OperationDataType.ARG,
+ data=self.other_meta_data,
+ logical_shape=other_logical_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=self.output_meta_data,
+ logical_shape=output_logical_shape,
+ )
+
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
@@ -460,9 +469,11 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
dim_partition_dict[0] = shard
# re-init the sharding spec
- input_sharding_spec.__init__(input_sharding_spec.device_mesh,
- entire_shape=input_physical_shape,
- dim_partition_dict=dim_partition_dict)
+ input_sharding_spec.__init__(
+ input_sharding_spec.device_mesh,
+ entire_shape=input_physical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return strategy
else:
return strategy
@@ -481,7 +492,8 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
recovered_stragies.extend(output)
else:
raise TypeError(
- f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
+ f"Found unexpected output type {type(output)} from the recover method of BmmTransform"
+ )
strategies = recovered_stragies
for index, strategies in enumerate(strategies):
strategies.name = f"{strategies.name}_{index}"
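
The `MatMulType` dispatch and the padding transform above follow `torch.matmul`'s documented rules: a 1-D first operand has a 1 prepended (a 1-D second operand has a 1 appended), and the padded dimension is squeezed from the result, which is what `Padder.recover` undoes on the sharding specs. A simplified sketch of the dispatch alongside the padding behaviour (not exhaustive; the 1-D x 2-D case is omitted here):

```python
# Simplified classifier mirroring the MatMulType enum.
import torch

def matmul_type(dim_in: int, dim_other: int) -> str:
    if dim_in == 1 and dim_other == 1:
        return "DOT"
    if dim_in == 2 and dim_other == 2:
        return "MM"
    if dim_in == 2 and dim_other == 1:
        return "MV"
    if max(dim_in, dim_other) >= 3:
        return "BMM"
    raise ValueError("combination not covered by this sketch")

a, b = torch.randn(5), torch.randn(3, 5, 7)
print(matmul_type(a.dim(), b.dim()))  # BMM
# the 1-D operand is treated as (1, 5); the padded dim is removed afterwards
print(torch.matmul(a, b).shape)       # torch.Size([3, 7])
```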
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index b4b7b0e794d1..d2bad39dcbb9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -8,7 +8,6 @@
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
- OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
@@ -23,21 +22,23 @@
class NodeHandler(ABC):
- '''
+ """
The NodeHandler is an abstract class used to generate every possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
- '''
-
- def __init__(self,
- node: Node,
- device_mesh: DeviceMesh,
- strategies_vector: StrategiesVector,
- shard_option: ShardOption = ShardOption.STANDARD,
- solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
+ """
+
+ def __init__(
+ self,
+ node: Node,
+ device_mesh: DeviceMesh,
+ strategies_vector: StrategiesVector,
+ shard_option: ShardOption = ShardOption.STANDARD,
+ solver_perference: SolverPerference = SolverPerference.STANDARD,
+ ) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
@@ -68,8 +69,9 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
- assert hasattr(node, 'strategies_vector'), \
- f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+ assert hasattr(
+ node, "strategies_vector"
+ ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost."
prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
@@ -80,10 +82,10 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
resharding_costs[node] = []
def _compute_resharding_cost(
- prev_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]],
- data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+ prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+ ) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
@@ -94,30 +96,35 @@ def _compute_resharding_cost(
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
- prev_sharding_spec, current_sharding_spec)
-
- resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
- bwd=consistency_cost["backward"] * size_per_elem_bytes,
- total=consistency_cost["total"] * size_per_elem_bytes)
+ prev_sharding_spec, current_sharding_spec
+ )
+
+ resharding_cost = TrainCycleItem(
+ fwd=consistency_cost["forward"] * size_per_elem_bytes,
+ bwd=consistency_cost["backward"] * size_per_elem_bytes,
+ total=consistency_cost["total"] * size_per_elem_bytes,
+ )
return resharding_cost
else:
# This raise is used to check if we have missed any type of data.
# It could be merged into Parameter branch, which means we won't handle
# non-tensor arguments.
- raise ValueError(f'Unsupported data type {type(data)}')
+ raise ValueError(f"Unsupported data type {type(data)}")
else:
- assert isinstance(prev_sharding_spec, (tuple, list)), \
- f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
- or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+ assert isinstance(
+ prev_sharding_spec, (tuple, list)
+ ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
+ or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}"
fwd_cost = 0
bwd_cost = 0
total_cost = 0
- for index, (prev_sharding_spec_item,
- current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
- current_sharding_spec)):
- item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
- data[index])
+ for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
+ zip(prev_sharding_spec, current_sharding_spec)
+ ):
+ item_cost = _compute_resharding_cost(
+ prev_sharding_spec_item, current_sharding_spec_item, data[index]
+ )
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
@@ -138,17 +145,17 @@ def get_target_function(self) -> callable:
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
- if self.node.op in ('placeholder', 'get_attr', 'output'):
+ if self.node.op in ("placeholder", "get_attr", "output"):
return None
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.graph.owning_module.get_submodule(self.node.target)
- elif self.node.op == 'call_function':
+ elif self.node.op == "call_function":
target = self.node.target
- elif self.node.op == 'call_method':
+ elif self.node.op == "call_method":
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else:
- raise ValueError(f'Unsupported node type: {self.node.op}')
+ raise ValueError(f"Unsupported node type: {self.node.op}")
return target
@@ -221,7 +228,6 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
"""
Define which generators should be used by this NodeHandler object.
"""
- pass
@abstractmethod
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -244,7 +250,6 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
- pass
class MetaInfoNodeHandler(NodeHandler):
@@ -278,19 +283,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet, ')
+ logger.warning(f"The target function {target} is not patched yet, ")
return self.strategies_vector
class ModuleHandler(NodeHandler):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience
- assert self.node.graph.owning_module is not None, \
- f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+ assert (
+ self.node.graph.owning_module is not None
+            ), "The graph is not associated with a module; please make sure it can be used to instantiate a GraphModule object."
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False))
@@ -333,6 +338,6 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet')
+ logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector
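
`_compute_resharding_cost` above bills resharding in bytes: the shape-consistency manager returns per-element communication volumes, which are scaled by the dtype's element size via the `torch.tensor([], dtype=dtype).element_size()` idiom. In isolation:

```python
# Convert a per-element communication volume into bytes for a given dtype,
# the same scaling used when accumulating resharding costs.
import torch

def bytes_for(num_elements: int, dtype: torch.dtype) -> int:
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return num_elements * size_per_elem_bytes

print(bytes_for(1024, torch.float16))  # 2048
print(bytes_for(1024, torch.float32))  # 4096
```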
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
index 4e71ccba95a7..facf19560596 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
-__all__ = ['NormPoolingHandler']
+__all__ = ["NormPoolingHandler"]
@operator_registry.register(torch.nn.MaxPool1d)
@@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
index ed120a8c3d6d..89906a205e87 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
@@ -8,7 +8,7 @@
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
-__all__ = ['OutputHandler']
+__all__ = ["OutputHandler"]
class OutputHandler(NodeHandler):
@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler):
A OutputHandler which deals with the sharding strategies for Output Node.
"""
- def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- output_option: str) -> None:
+ def __init__(
+ self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
@@ -35,11 +36,11 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
for index, input_node in enumerate(self.predecessor_node):
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
- name_key = f'input_{index}'
+ name_key = f"input_{index}"
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
- assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
+ assert len(output_meta_data) > 0, f"Output node {self.node} has no input node."
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
index 91e4a5105a08..75f07168e47b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator
-__all__ = ['PermuteHandler']
+__all__ = ["PermuteHandler"]
@operator_registry.register(torch.Tensor.permute)
@@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
- assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
+                    assert isinstance(arg, int), "The argument in the permute node should be either a Node or an int."
permute_dims.append(arg)
else:
# torch.permute (input, dims)
@@ -51,8 +51,8 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
- arg,
- (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
+ arg, (tuple, list)
+                    ), "The argument in the permute node should be a Node, a Tuple[int] or a List[int]."
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
@@ -61,7 +61,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
if permute_dims[i] < 0:
permute_dims[i] += num_dims
- physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
+ physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -69,7 +69,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
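
Before building the `permute_dims` operand, the handler converts negative dims to their positive equivalents. A dependency-free sketch of that normalization (the helper name is ours, for illustration only):

```python
from typing import List


def normalize_permute_dims(permute_dims: List[int], num_dims: int) -> List[int]:
    # convert negative dim values to their positive equivalents,
    # as the handler does before constructing the "permute_dims" operand
    return [d + num_dims if d < 0 else d for d in permute_dims]


# permuting a 3-D tensor with (0, -1, -2) is the same as permuting with (0, 2, 1)
assert normalize_permute_dims([0, -1, -2], 3) == [0, 2, 1]
```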
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
index e4f40fc935a4..461bc2935780 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
@@ -8,7 +8,7 @@
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
-__all__ = ['PlaceholderHandler']
+__all__ = ["PlaceholderHandler"]
class PlaceholderHandler(NodeHandler):
@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler):
    A PlaceholderHandler, which deals with the sharding strategies for a Placeholder Node.
"""
- def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- placeholder_option: str) -> None:
+ def __init__(
+ self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
@@ -25,7 +26,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
+ PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 730a90d74cf8..f663fc9695d3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,11 +1,9 @@
class Registry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
                # support registering a list of items for this func
@@ -18,7 +16,7 @@ def wrapper(func):
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -26,4 +24,4 @@ def has(self, source):
return source in self.store
-operator_registry = Registry('operator')
+operator_registry = Registry("operator")
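
The `Registry` above is a small decorator-driven lookup table: `@operator_registry.register(...)` maps one operator, or a list of operators, to a handler class, and `get` retrieves it. A standalone sketch of the same pattern with plain string keys instead of torch operators:

```python
class Registry:
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        def wrapper(func):
            # support registering a list of keys for the same handler
            if isinstance(source, (list, tuple)):
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store, f"{source} not found in the {self.name} registry"
        return self.store[source]


operator_registry = Registry("operator")


@operator_registry.register(["max_pool1d", "max_pool2d"])
class PoolingHandler:
    pass


assert operator_registry.get("max_pool2d") is PoolingHandler
```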
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
index 743a1f90eaaf..6e883ea64736 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
-__all__ = ['SoftmaxHandler']
+__all__ = ["SoftmaxHandler"]
@operator_registry.register(torch.nn.Softmax)
@@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
- softmax_dim = self.node.kwargs['dim']
+ softmax_dim = self.node.kwargs["dim"]
num_dims = self.node.args[0]._meta_data.dim()
        # convert a negative dim value to its positive equivalent
if softmax_dim < 0:
softmax_dim += num_dims
- physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
+ physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -49,7 +49,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
index 653d158b7c36..4c32529a5d5b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
-__all__ = ['SplitHandler']
+__all__ = ["SplitHandler"]
@operator_registry.register(torch.Tensor.split)
@@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
split_dim = self.node.args[2]
else:
if self.node.kwargs:
- split_dim = self.node.kwargs['dim']
+ split_dim = self.node.kwargs["dim"]
else:
split_dim = 0
@@ -48,7 +48,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
split_dim += num_dims
split_info = (split_size, split_dim)
- physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
+ physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -56,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
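
The split handler resolves `split_info = (split_size, split_dim)` from positional args when present, falling back to the `dim` kwarg and finally to dim 0, then normalizes a negative dim. A minimal sketch of that resolution logic (the function name and signature are ours):

```python
from typing import Optional, Tuple


def resolve_split_info(
    split_size: int, num_dims: int, dim_arg: Optional[int] = None, dim_kwarg: Optional[int] = None
) -> Tuple[int, int]:
    # a positional dim wins; otherwise fall back to the kwarg, then to 0
    if dim_arg is not None:
        split_dim = dim_arg
    elif dim_kwarg is not None:
        split_dim = dim_kwarg
    else:
        split_dim = 0
    if split_dim < 0:
        split_dim += num_dims
    return (split_size, split_dim)


assert resolve_split_info(2, 3, dim_kwarg=-1) == (2, 2)
assert resolve_split_info(2, 3) == (2, 0)
```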
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index db1f31521c86..1fc7f613716b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -29,11 +29,31 @@
from .where_generator import WhereGenerator
__all__ = [
- 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
- 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
- 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
- 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
- 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
- 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
- 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
+ "StrategyGenerator",
+ "DotProductStrategyGenerator",
+ "MatVecStrategyGenerator",
+ "LinearProjectionStrategyGenerator",
+ "BatchedMatMulStrategyGenerator",
+ "ConvStrategyGenerator",
+ "UnaryElementwiseGenerator",
+ "BatchNormStrategyGenerator",
+ "GetItemStrategyGenerator",
+ "TensorStrategyGenerator",
+ "TensorTupleStrategyGenerator",
+ "LayerNormGenerator",
+ "PlaceholderGenerator",
+ "OutputGenerator",
+ "WhereGenerator",
+ "NormalPoolStrategyGenerator",
+ "BinaryElementwiseStrategyGenerator",
+ "GetattrGenerator",
+ "TensorConstructorGenerator",
+ "EmbeddingStrategyGenerator",
+ "SumGenerator",
+ "SoftmaxGenerator",
+ "ViewGenerator",
+ "PermuteGenerator",
+ "TransposeGenerator",
+ "SplitGenerator",
+ "DefaultReshapeGenerator",
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index 416dc9c29cad..9c766b1014c8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -14,7 +14,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['BatchNormStrategyGenerator']
+__all__ = ["BatchNormStrategyGenerator"]
class BatchNormStrategyGenerator(StrategyGenerator):
@@ -30,28 +30,31 @@ class BatchNormStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For BatchNorm1d, the dim of the input data should be 3 ([N, C, L]).
        For BatchNorm2d, the dim of the input data should be 4 ([N, C, H, W]).
        For BatchNorm3d, the dim of the input data should be 5 ([N, C, D, H, W]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: a constant coefficient needs to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -69,23 +72,24 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output"),
- 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
- 'running_var': self._compute_size_in_bytes(strategy, "running_var"),
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
+ "running_mean": self._compute_size_in_bytes(strategy, "running_mean"),
+ "running_var": self._compute_size_in_bytes(strategy, "running_var"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum(
- [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
@@ -93,36 +97,29 @@ def update_memory_cost(self, strategy: ShardingStrategy):
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum(
- [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost,
- buffer=fwd_buffer_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost,
+ parameter=fwd_parameter_cost + bwd_parameter_cost,
+ buffer=fwd_buffer_cost,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "output": {
- 1: [mesh_dim_0]
- },
- "running_mean": {
- 0: [mesh_dim_0]
- },
- "running_var": {
- 0: [mesh_dim_0]
- },
+ "input": {1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0]},
+ "output": {1: [mesh_dim_0]},
+ "running_mean": {0: [mesh_dim_0]},
+ "running_var": {0: [mesh_dim_0]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -132,29 +129,21 @@ def split_input_channel(self, mesh_dim_0):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "running_mean": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "running_var": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
+ "output": {1: [mesh_dim_0, mesh_dim_1]},
+ "running_mean": {0: [mesh_dim_0, mesh_dim_1]},
+ "running_var": {0: [mesh_dim_0, mesh_dim_1]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -164,13 +153,15 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+ name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -186,21 +177,19 @@ def non_split(self):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -218,27 +207,26 @@ def split_input_batch(self, mesh_dim_0):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
        # TODO: the temporary solution has no communication cost;
        # the above action should be added after the SyncBN replace pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -256,19 +244,22 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
        # TODO: the temporary solution has no communication cost;
        # the above action should be added after the SyncBN replace pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
@@ -304,20 +295,23 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
        # TODO: the temporary solution has no communication cost;
        # the above action should be added after the SyncBN replace pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
        Generate every possible strategy for a BatchNorm node and record all strategies in the strategies_vector.
- '''
+ """
strategy_list = []
# RS = RS x S
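
The fwd/bwd/total memory accounting above recurs in the conv and embedding generators below: activation, parameter and buffer sizes are summed separately per phase and then combined, with buffers (running_mean/running_var) counted once on the forward side. A simplified sketch of that bookkeeping, with stand-in dataclasses rather than the real colossalai cost types:

```python
from dataclasses import dataclass


@dataclass
class MemoryCost:
    # simplified stand-in for colossalai's MemoryCost record
    activation: float = 0.0
    parameter: float = 0.0
    buffer: float = 0.0


@dataclass
class TrainCycleItem:
    fwd: MemoryCost
    bwd: MemoryCost
    total: MemoryCost


def combine_memory_cost(fwd: MemoryCost, bwd: MemoryCost) -> TrainCycleItem:
    # buffers (e.g. running_mean / running_var) only count once, on the forward side
    total = MemoryCost(
        activation=fwd.activation + bwd.activation,
        parameter=fwd.parameter + bwd.parameter,
        buffer=fwd.buffer,
    )
    return TrainCycleItem(fwd=fwd, bwd=bwd, total=total)


cost = combine_memory_cost(
    MemoryCost(activation=4.0, parameter=1.0, buffer=0.5),
    MemoryCost(activation=3.0, parameter=1.0),
)
assert cost.total.activation == 7.0 and cost.total.buffer == 0.5
```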
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
index d27cc046eaf3..c7da0034ec3b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
@@ -14,7 +14,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['BinaryElementwiseStrategyGenerator']
+__all__ = ["BinaryElementwiseStrategyGenerator"]
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- assert len(self.op_data) == 3, \
- f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
+ assert (
+ len(self.op_data) == 3
+ ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}"
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
- raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
+ raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.")
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        # input, other and output all have the same shape
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# compute fwd memory cost in bytes
        # as elementwise ops are not memory-intensive,
        # we approximate the fwd memory cost to be the output size
        # and the backward memory cost to be the grads of input and other
- input_bytes = self._compute_size_in_bytes(strategy, 'input')
- other_bytes = self._compute_size_in_bytes(strategy, 'other')
- output_bytes = self._compute_size_in_bytes(strategy, 'output')
+ input_bytes = self._compute_size_in_bytes(strategy, "input")
+ other_bytes = self._compute_size_in_bytes(strategy, "other")
+ output_bytes = self._compute_size_in_bytes(strategy, "output")
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
@@ -66,7 +67,7 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
@@ -86,21 +87,22 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
- dim_partition_dict_mapping = dict(input=dim_partition_dict,
- other=dim_partition_dict,
- output=dim_partition_dict)
+ dim_partition_dict_mapping = dict(
+ input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict
+ )
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
- sharding_seq = sharding_spec_mapping['input'].sharding_sequence
- name = f'{sharding_seq} = {sharding_seq} {sharding_seq}'
+ sharding_seq = sharding_spec_mapping["input"].sharding_sequence
+ name = f"{sharding_seq} = {sharding_seq} {sharding_seq}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
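
Because the op is elementwise, the generator reuses a single dim-partition dict for input, other and output, and simply sweeps every placement of the mesh axes over the output dims. A rough re-implementation of that enumeration under our own simplified convention (a partition dict maps a tensor dim to the mesh axes sharding it):

```python
from itertools import permutations
from typing import Dict, List


def enumerate_2d_shardings(mesh_dim_0: int, mesh_dim_1: int, dim_size: int) -> List[Dict[int, List[int]]]:
    """All ways to place two mesh axes on a dim_size-dim tensor (sketch)."""
    cases: List[Dict[int, List[int]]] = []
    # both mesh axes fused on the same tensor dim, e.g. S01
    for d in range(dim_size):
        cases.append({d: [mesh_dim_0, mesh_dim_1]})
    # mesh axes on two different tensor dims, e.g. S0 on dim i, S1 on dim j
    for i, j in permutations(range(dim_size), 2):
        cases.append({i: [mesh_dim_0], j: [mesh_dim_1]})
    return cases


# a 2-D tensor on a 2-D mesh: 2 fused placements + 2 split placements
assert len(enumerate_2d_shardings(0, 1, 2)) == 4
```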
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index e605a68a326b..5208f61543bb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For Conv1d, the dim of the input data should be 3 ([N, C, L]).
        For Conv2d, the dim of the input data should be 4 ([N, C, H, W]).
        For Conv3d, the dim of the input data should be 5 ([N, C, D, H, W]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
            # bias add is an element-wise operation, so the cost is equal to the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -76,14 +77,14 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -100,26 +101,20 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
@@ -132,7 +127,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -140,7 +136,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -148,38 +145,41 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -196,7 +196,8 @@ def split_input_batch(self, mesh_dim_0):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -204,42 +205,45 @@ def split_input_batch(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
- "other": {
- 0: [mesh_dim_1]
- },
+ "other": {0: [mesh_dim_1]},
"output": {
0: [mesh_dim_0],
},
@@ -254,7 +258,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
@@ -263,7 +268,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -271,7 +277,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
@@ -279,23 +286,27 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -322,23 +333,27 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
- name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
+ name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R"
dim_partition_dict_mapping = {
"input": {
@@ -360,17 +375,20 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -395,17 +413,20 @@ def split_weight_out_channel(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+ name = f"RR = RR x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -418,13 +439,13 @@ def non_split(self):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
dim_partition_dict_mapping = {
"input": {
@@ -447,14 +468,16 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
@@ -464,23 +487,27 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
@@ -501,17 +528,20 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
"other": {
@@ -535,13 +565,16 @@ def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
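
The conv generator's compute-cost note (e.g. "2D: (H * W) * N * Cout * Cin * kernel") can be sanity-checked in isolation. A back-of-the-envelope sketch of that per-device estimate under our own simplifications — no TFLOPS division, and bias add counted as the product of the output shape:

```python
import operator
from functools import reduce
from typing import Sequence


def conv2d_compute_size(
    input_shape: Sequence[int],
    out_channels: int,
    kernel_elems: int,
    output_shape: Sequence[int],
    has_bias: bool = False,
) -> float:
    """Rough fwd 'computation size' for a sharded Conv2d: (H * W) * N * Cout * Cin * kernel."""
    n, c_in, h, w = input_shape
    fwd = (h * w) * n * out_channels * c_in * kernel_elems
    if has_bias:
        # bias add is elementwise over the output
        fwd += reduce(operator.mul, output_shape)
    return fwd


# the per-device cost shrinks when the batch dim is sharded: N goes from 8 to 4
full = conv2d_compute_size((8, 3, 32, 32), 16, 9, (8, 16, 32, 32))
half = conv2d_compute_size((4, 3, 32, 32), 16, 9, (4, 16, 32, 32))
assert half * 2 == full
```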
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
index 82a04ab52e73..385a8886f231 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -27,16 +25,16 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: the computation cost for the embedding handler is currently estimated as dense computation.
It may not be accurate.
- '''
+ """
        # TODO: estimate the embedding computation cost as a sparse operation
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
@@ -55,9 +53,9 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -75,14 +73,15 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
- name = f'RR = R x RR'
+ name = f"RR = R x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -92,18 +91,16 @@ def non_split(self):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -118,7 +115,8 @@ def split_input(self, mesh_dim_0):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -126,17 +124,20 @@ def split_input(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -159,7 +160,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -167,7 +169,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -175,22 +178,23 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
@@ -207,7 +211,8 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -215,17 +220,20 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -245,17 +253,20 @@ def split_embedding_dim(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
@@ -275,13 +286,16 @@ def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
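
Strategy names such as `S{0}S{1} = S{0} x RS{1}` follow an "output = input x other" sharding notation: `R` marks a replicated tensor dim, `S0`/`S1` a dim sharded along mesh axis 0 or 1, and `S01` a dim sharded along both axes at once (flattened 1-D parallelism). A tiny decoder for one sharding sequence, written against our reading of that convention:

```python
from typing import List


def decode_sharding_seq(seq: str) -> List[List[int]]:
    """'S01R' -> [[0, 1], []]: the mesh axes sharding each tensor dim, [] = replicated."""
    dims: List[List[int]] = []
    i = 0
    while i < len(seq):
        if seq[i] == "R":
            dims.append([])
            i += 1
        elif seq[i] == "S":
            i += 1
            axes = []
            while i < len(seq) and seq[i].isdigit():
                axes.append(int(seq[i]))
                i += 1
            dims.append(axes)
        else:
            raise ValueError(f"unexpected character {seq[i]!r} in {seq!r}")
    return dims


# the output of split_input_and_embedding_dim: dim 0 on mesh axis 0, dim 1 on mesh axis 1
assert decode_sharding_seq("S0S1") == [[0], [1]]
assert decode_sharding_seq("S01R") == [[0, 1], []]
```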
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
index bbeb9a639c83..cc8d5771f28e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['GetattrGenerator']
+__all__ = ["GetattrGenerator"]
class GetattrGenerator(StrategyGenerator):
@@ -26,10 +26,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -47,7 +47,7 @@ def update_memory_cost(self, strategy: ShardingStrategy):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
@@ -78,7 +78,8 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
index 0aeb2e0d4079..6f01d9cc7f8e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
@@ -1,19 +1,13 @@
import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.logging import get_dist_logger
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
+__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"]
class GetItemStrategyGenerator(FollowingStrategyGenerator):
@@ -35,12 +29,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -58,27 +52,29 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 1 and 2.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- getitem_index = self.op_data['index'].data
+ getitem_index = self.op_data["index"].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
- strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
+ strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
+ )
int_index = False
if isinstance(getitem_index, int):
@@ -120,9 +116,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
except ShardingSpecException as e:
logger.debug(e)
continue
@@ -137,9 +135,9 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 3.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -158,13 +156,15 @@ def collate_strategies(self) -> List[ShardingStrategy]:
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
- input_sharding_info += f'{sharding_spec.sharding_sequence}, '
+ input_sharding_info += f"{sharding_spec.sharding_sequence}, "
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
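
The memory-cost hunks above all follow the same pattern: build forward and backward `MemoryCost` values, then sum them into the `total` field of a `TrainCycleItem`. A self-contained sketch with simplified stand-ins for the real classes (which live in `colossalai.auto_parallel.tensor_shard.sharding_strategy`):

```python
from dataclasses import dataclass

@dataclass
class MemoryCost:      # simplified stand-in
    activation: float = 0.0
    parameter: float = 0.0

@dataclass
class TrainCycleItem:  # simplified stand-in
    fwd: MemoryCost
    bwd: MemoryCost
    total: MemoryCost

fwd_mem_cost = MemoryCost(activation=4.0, parameter=1.0)
bwd_mem_cost = MemoryCost(activation=4.0, parameter=1.0)
total_mem_cost = MemoryCost(
    activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
    parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
assert memory_cost.total.activation == 8.0
```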
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
index 65b173bbf65d..e5b7e6f25d4d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
@@ -18,7 +18,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['LayerNormGenerator']
+__all__ = ["LayerNormGenerator"]
class LayerNormGenerator(StrategyGenerator):
@@ -31,21 +31,21 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
        # TODO: a constant coefficient needs to be added.
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
if self.has_bias:
            # bias add is an element-wise operation, so the cost equals the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
        # in the LayerNorm context, the batch dimensions are all the dimensions that do not join the normalization.
- input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
+ input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
@@ -62,18 +62,18 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -90,8 +90,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -120,7 +121,8 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
@@ -128,12 +130,15 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["bias"] = bias_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -155,7 +160,7 @@ def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimensio
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+ name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -168,14 +173,16 @@ def non_split(self):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
- '''
+ """
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
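
A worked example of the LayerNorm cost formula in the hunk above, with made-up per-device shapes: every dimension not covered by the normalized (kernel) shape counts as a batch dimension, and the forward cost is the product of the two factors.

```python
import operator
from functools import reduce

sharded_input_shape = (8, 128, 768)   # e.g. [batch, seq, hidden] per device
sharded_weight_shape = (768,)         # the normalized dimensions

input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]  # (8, 128)
input_batch_product = reduce(operator.mul, input_batch_shape, 1)       # 1024
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)    # 768
forward_compute_cost = input_batch_product * norm_kernel_product
assert forward_compute_cost == 8 * 128 * 768
```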
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index aa1581b99e0f..fb182afb9175 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -1,5 +1,4 @@
import operator
-from ast import arg
from functools import reduce
from typing import List
@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- size_mapping['bias'] = bias_size
+ size_mapping["bias"] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
@@ -41,45 +40,47 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
- bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
+ bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + 0)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
def no_split(self):
- name = f'R = R dot R'
- dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
+ name = f"R = R dot R"
+ dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
- name = f'R = S{mesh_dim} dot S{mesh_dim}'
+ name = f"R = S{mesh_dim} dot S{mesh_dim}"
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
@@ -87,14 +88,17 @@ def split_one_dim(self, mesh_dim):
# get communication action
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -112,19 +116,18 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class MatVecStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
@@ -133,67 +136,69 @@ def no_split(self):
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
- name = f'S{mesh_dim}R = S{mesh_dim}R x R'
+ name = f"S{mesh_dim}R = S{mesh_dim}R x R"
# get sharding spec
dim_partition_dict = {
- "input": {
- 0: [mesh_dim]
- },
+ "input": {0: [mesh_dim]},
"other": {},
- "output": {
- 0: [mesh_dim]
- },
+ "output": {0: [mesh_dim]},
}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=2)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=2,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -209,12 +214,13 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-
- def __init__(self,
- operation_data_mapping,
- device_mesh,
- linear_projection_type='linear',
- solver_perference=SolverPerference.STANDARD):
+ def __init__(
+ self,
+ operation_data_mapping,
+ device_mesh,
+ linear_projection_type="linear",
+ solver_perference=SolverPerference.STANDARD,
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
@@ -224,17 +230,17 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+            fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
def dp_strategies(self) -> List[ShardingStrategy]:
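
A worked example of the M*N*P cost model in `update_compute_cost` above: for C[M, N] = A[M, P] @ B[P, N], the forward pass counts only multiplies and the backward pass is modeled as twice the forward. The shapes below are illustrative toy values.

```python
import operator
from functools import reduce

sharded_input_shape = (16, 512)     # A: [M, P] per device
sharded_other_shape = (512, 1024)   # B: [P, N] per device

dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])  # M = 16
dim_n_val = sharded_other_shape[-1]                         # N = 1024
dim_p_val = sharded_other_shape[0]                          # P = 512
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
assert (fwd_compute_cost, bwd_compute_cost) == (16 * 1024 * 512, 2 * 16 * 1024 * 512)
```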
@@ -301,28 +307,21 @@ def collate_strategies(self) -> List[ShardingStrategy]:
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- -1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {-1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
}
        # linear bias only has one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -333,75 +332,75 @@ def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['input'] = input_comm_action
- communication_action_mapping['other'] = other_comm_action
+ communication_action_mapping["input"] = input_comm_action
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
# get sharding spec mapping
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
}
        # linear bias only has one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -412,66 +411,64 @@ def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['other'] = other_comm_action
- communication_action_mapping['output'] = output_comm_action
+ communication_action_mapping["other"] = other_comm_action
+ communication_action_mapping["output"] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
# get sharding specs
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "bias": {-1: [mesh_dim_1]},
+ "output": {-1: [mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -482,34 +479,34 @@ def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping["input"] = input_comm_action
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
- name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
+ name = f"RR = RS{mesh_dim} x S{mesh_dim}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim]
- },
- "other": {
- 0: [mesh_dim]
- },
+ "input": {-1: [mesh_dim]},
+ "other": {0: [mesh_dim]},
"bias": {},
"output": {},
}
@@ -520,32 +517,29 @@ def recompute_split_both_contract(self, mesh_dim):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
- name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
+ name = f"RS{mesh_dim} = RR x RS{mesh_dim}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim]
- },
- "bias": {
- -1: [mesh_dim]
- },
- "output": {
- -1: [mesh_dim]
- },
+ "other": {-1: [mesh_dim]},
+ "bias": {-1: [mesh_dim]},
+ "output": {-1: [mesh_dim]},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
@@ -554,93 +548,94 @@ def split_rhs_space_only(self, mesh_dim):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- communication_action_mapping['input'] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["input"] = input_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
        # linear bias only has one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {},
}
@@ -652,32 +647,29 @@ def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
+ "other": {-1: [mesh_dim_0, mesh_dim_1]},
+ "bias": {-1: [mesh_dim_0, mesh_dim_1]},
+ "output": {-1: [mesh_dim_0, mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -687,20 +679,23 @@ def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+ name = f"RR = RR x RR"
# get sharding spec
dim_partition_dict_mapping = {
@@ -717,22 +712,24 @@ def non_split(self):
# get communication action
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other has 2 dim
- input_data = self.op_data['input']
- other_data = self.op_data['other']
+ input_data = self.op_data["input"]
+ other_data = self.op_data["other"]
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
- bias_data = self.op_data['bias']
+ bias_data = self.op_data["bias"]
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
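
The hunk below touches `_pop_batch_dim_sharding_for_output`, which drops the batch-dim partition from the output spec and re-indexes the remaining dimensions when the batch dimension is squeezed. A minimal sketch of that dict transformation (a hypothetical free function mirroring the method's logic):

```python
def pop_batch_dim_sharding_for_output(dim_partition_dict: dict) -> dict:
    output = dict(dim_partition_dict["output"])
    output.pop(0, None)                                        # remove partition for dim 0
    shifted = {dim - 1: axes for dim, axes in output.items()}  # shift remaining dims down by 1
    return {**dim_partition_dict, "output": shifted}

# Dim 2 sharded over mesh axis 1 becomes dim 1 once the batch dim is squeezed.
assert pop_batch_dim_sharding_for_output({"output": {0: [0], 2: [1]}}) == {"output": {1: [1]}}
```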
@@ -757,37 +754,38 @@ def __init__(self, *args, **kwargs):
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
# remove partition dict for dim 0
- dim_partition_dict['output'].pop(0, None)
+ dim_partition_dict["output"].pop(0, None)
# decrease the remaining dim index by 1
temp_dim_partition = {}
- keys = list(dim_partition_dict['output'].keys())
+ keys = list(dim_partition_dict["output"].keys())
for key in keys:
- val = dim_partition_dict['output'].pop(key)
+ val = dim_partition_dict["output"].pop(key)
temp_dim_partition[key - 1] = val
- dim_partition_dict['output'].update(temp_dim_partition)
+ dim_partition_dict["output"].update(temp_dim_partition)
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
- if 'bias' in self.op_data:
- bias_op_data = self.op_data['bias']
+ if "bias" in self.op_data:
+ bias_op_data = self.op_data["bias"]
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
- self.op_data['output'].data.shape)
+ fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce(
+ operator.mul, self.op_data["output"].data.shape
+ )
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
- name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+ name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}"
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
@@ -799,30 +797,27 @@ def split_one_batch_dim(self, mesh_dim):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- }
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -832,35 +827,28 @@ def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+ name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "bias": {
- 0: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0]},
+ "bias": {0: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -869,46 +857,40 @@ def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
- communication_action_mapping['other'].arg_index += 1
+ communication_action_mapping["other"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "bias": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "bias": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 2: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -917,43 +899,41 @@ def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.BEFORE)
- communication_action_mapping['bias'] = bias_comm_action
+ comm_type=CommType.BEFORE,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
        # for the addbmm case, input is the second argument instead of the first.
- communication_action_mapping['input'].arg_index += 1
+ communication_action_mapping["input"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0],
- }
+ },
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -962,24 +942,28 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
-
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
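
The batched-matmul strategies above extend the notation with subscripts: `Sb` shards the batch dimension, `Si`/`Sj` shard a non-contracted dimension of the lhs/rhs, and `Sk` shards the contracted dimension (which is why `split_batch_dim_both_contract` adds an all-reduce on the output). One such name, rendered exactly as the f-strings in the diff do:

```python
mesh_dim_0, mesh_dim_1 = 0, 1
name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
assert name == "Sb0R = Sb0Sk1 x Sb0Sk1"
```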
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
index b7db42f8f67e..b307e38b5b6d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
@@ -21,28 +21,31 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
        In this sanity check, we need to make sure the input data has the correct dimension size.
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into Pool op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
@@ -61,8 +64,8 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -88,12 +91,16 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ name = (
+ f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ )
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
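
A worked example of the 1D pooling cost comment above ("(Lout) * N * C * kernel"), with toy per-device shapes; the integer `kernel_size` stands in for `self.op_data["other"].data`, which the code broadcasts when given a plain int.

```python
sharded_output_shape = (8, 64, 100)  # [N, C, L_out] per device (toy values)
kernel_size = 3

N, C, L_out = sharded_output_shape
# "1D: (Lout) * N * C * kernel" from the comment above
forward_compute_cost = L_out * N * C * kernel_size
assert forward_compute_cost == 100 * 8 * 64 * 3
```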
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
index 69d1642d4f80..33fb1ac5c5be 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
@@ -12,7 +12,7 @@
from .strategy_generator import OutputStrategyGenerator
-__all__ = ['OutputGenerator']
+__all__ = ["OutputGenerator"]
class OutputGenerator(OutputStrategyGenerator):
@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node], output_option: str):
+ def __init__(
+ self,
+ operation_data_mapping: Dict[str, OperationData],
+ device_mesh: DeviceMesh,
+ predecessor_nodes: List[Node],
+ output_option: str,
+ ):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
@@ -33,9 +38,9 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
@@ -65,16 +70,18 @@ def replica_strategy(self) -> List[ShardingStrategy]:
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
- dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
+ dim_partition_dict_mapping["output"] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Output'
+ name = "Replica Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
@@ -82,19 +89,15 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi
Generate distributed strategy for output node.
"""
        # TODO: need to take care of the case when only the first element of the output needs to be sharded.
- output_op_data = self.op_data['output']
+ output_op_data = self.op_data["output"]
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
- "output": [{
- 0: mesh_list
- }] * length,
+ "output": [{0: mesh_list}] * length,
}
else:
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
@@ -103,19 +106,21 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Output'
+ name = "Distributed Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
- if self.output_option == 'replicated':
+ if self.output_option == "replicated":
strategy_list.append(self.replica_strategy())
- elif self.output_option == 'distributed':
+ elif self.output_option == "distributed":
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
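
For readers following the replicated/distributed options above: a `dim_partition_dict` such as `{0: mesh_list}` records that tensor dim 0 is sharded across the listed device-mesh axes, and absent dims stay replicated. A minimal sketch of how the per-device shard shape follows (`shard_shape` is a hypothetical helper, not a library API):

```python
# Hypothetical helper illustrating the dim_partition_dict convention:
# {tensor_dim: [mesh_axis, ...]} shards that dim over the product of the
# listed mesh-axis sizes; dims absent from the dict stay replicated.
def shard_shape(global_shape, dim_partition_dict, mesh_shape):
    local = list(global_shape)
    for dim, mesh_axes in dim_partition_dict.items():
        degree = 1
        for axis in mesh_axes:
            degree *= mesh_shape[axis]
        assert local[dim] % degree == 0, "sharded dim must divide evenly"
        local[dim] //= degree
    return tuple(local)

# An (8, 16) tensor on a 2 x 2 device mesh:
print(shard_shape((8, 16), {0: [0, 1]}, mesh_shape=(2, 2)))  # (2, 16)
print(shard_shape((8, 16), {}, mesh_shape=(2, 2)))           # (8, 16), replicated
```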
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
index 779a7ced93bb..df0862a396d2 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['PlaceholderGenerator']
+__all__ = ["PlaceholderGenerator"]
class PlaceholderGenerator(StrategyGenerator):
@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- placeholder_option: str):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
@@ -31,10 +32,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -58,11 +59,13 @@ def replica_placeholder(self) -> ShardingStrategy:
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Placeholder'
+ name = "Replica Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -71,29 +74,31 @@ def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Placeholder'
+ name = "Distributed Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- if self.placeholder_option == 'distributed':
+ if self.placeholder_option == "distributed":
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
- assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
+ assert (
+ self.placeholder_option == "replicated"
+ ), f"placeholder_option {self.placeholder_option} is not supported"
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 24f75e352935..48f454553ac7 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -17,7 +17,7 @@
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
-__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
+__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"]
class ReshapeGenerator(FollowingStrategyGenerator):
@@ -33,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -56,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -77,8 +78,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- origin_shape = self.op_data['input'].data.shape
- tgt_shape = self.op_data['tgt_shape'].data
+ origin_shape = self.op_data["input"].data.shape
+ tgt_shape = self.op_data["tgt_shape"].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
@@ -86,8 +87,9 @@ def collate_strategies(self) -> List[ShardingStrategy]:
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
- dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
- reshape_mapping_dict)
+ dim_partition_dict_for_output = infer_output_dim_partition_dict(
+ dim_partition_dict_for_input, reshape_mapping_dict
+ )
else:
dim_partition_dict_for_output = {}
@@ -119,7 +121,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
@@ -127,10 +130,10 @@ def collate_strategies(self) -> List[ShardingStrategy]:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -139,9 +142,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -159,7 +164,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- permute_dims = self.op_data['permute_dims'].data
+ permute_dims = self.op_data["permute_dims"].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
@@ -177,9 +182,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -199,7 +206,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
- transpose_dims = self.op_data['transpose_dims'].data
+ transpose_dims = self.op_data["transpose_dims"].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
@@ -221,9 +228,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -242,7 +251,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- split_size, split_dim = self.op_data['split_info'].data
+ split_size, split_dim = self.op_data["split_info"].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
@@ -271,7 +280,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
@@ -282,7 +292,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -291,9 +301,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -341,16 +353,17 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -358,9 +371,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
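
The gather-before-view branch above exists because a view can merge a sharded dim with its neighbours, after which per-shard reshapes no longer compose into the global result. A standalone demonstration (NumPy used purely for illustration):

```python
import numpy as np

x = np.arange(8).reshape(4, 2)   # global tensor, logical shape (4, 2)
shards = np.split(x, 2, axis=1)  # dim 1 sharded across 2 devices

global_view = x.reshape(8)
naive_view = np.concatenate([s.reshape(4) for s in shards])

print(global_view)  # [0 1 2 3 4 5 6 7]
print(naive_view)   # [0 2 4 6 1 3 5 7] -- element order differs
# Hence the GATHER_FWD_SPLIT_BWD action: gather the input first whenever
# check_keep_sharding_status says the sharding cannot survive the reshape.
```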
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
index a1ebadd043e2..d4382f9941d2 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
@@ -4,21 +4,9 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-
-__all__ = ['SoftmaxGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SoftmaxGenerator"]
class SoftmaxGenerator(FollowingStrategyGenerator):
@@ -30,11 +18,11 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- '''
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ """
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
@@ -45,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -68,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -80,10 +69,10 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- softmax_dim = self.op_data['softmax_dim'].data
+ softmax_dim = self.op_data["softmax_dim"].data
if softmax_dim in dim_partition_dict_for_input:
- recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
+ dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
@@ -96,9 +85,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
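
Popping `softmax_dim` from the partition dict above encodes a correctness constraint: softmax normalizes over the entire dim, so a shard-local softmax would use the wrong denominator. A quick check (PyTorch, illustration only):

```python
import torch

x = torch.randn(6)
full = torch.softmax(x, dim=0)

# Shard-local softmax over two halves of the softmax dim:
halves = torch.chunk(x, 2, dim=0)
naive = torch.cat([torch.softmax(h, dim=0) for h in halves])

print(torch.allclose(full, naive))  # False: each half normalizes separately
# The generator therefore keeps the remaining dims' sharding but forces the
# softmax dim itself to be replicated on both input and output.
```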
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index d42429745c61..7bf2c8cc12a3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -39,7 +39,7 @@ def has_bias(self):
"""
A utility method to check for the existence of bias operand for convenience.
"""
- return 'bias' in self.op_data
+ return "bias" in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
@@ -49,8 +49,12 @@ def is_buffer(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
- def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
- communication_action_mapping: Dict[str, CommSpec]):
+ def get_sharding_strategy(
+ self,
+ name: str,
+ sharding_spec_mapping: Dict[str, ShardingSpec],
+ communication_action_mapping: Dict[str, CommSpec],
+ ):
"""
A factory method to produce a ShardingStrategy object.
@@ -80,24 +84,28 @@ def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
- data: any, logical_shape: any,
- dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+ data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]]
+ ) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
- sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=logical_shape,
- dim_partition_dict=dim_partition_dict)
+ sharding_spec = ShardingSpec(
+ device_mesh=self.device_mesh,
+ entire_shape=logical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
- data, logical_shape, dim_partition_dict):
+ data, logical_shape, dim_partition_dict
+ ):
sharding_spec.append(
- _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
+ _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)
+ )
return sharding_spec
else:
return None
@@ -116,31 +124,41 @@ def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
results[op_data] = v
return results
- def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]]):
+ def get_communication_spec(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ ):
"""
A factory method to produce a CommSpec object.
"""
- return CommSpec(comm_pattern=communication_pattern,
- sharding_spec=sharding_spec,
- logical_process_axis=logical_process_axis)
-
- def get_communication_action(self,
- sharding_spec: ShardingSpec,
- communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]],
- comm_type: CommType,
- arg_index: int = -1,
- key_for_kwarg: any = None) -> CommAction:
+ return CommSpec(
+ comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis
+ )
+
+ def get_communication_action(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ comm_type: CommType,
+ arg_index: int = -1,
+ key_for_kwarg: any = None,
+ ) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
- return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
- communication_pattern=communication_pattern,
- logical_process_axis=logical_process_axis),
- comm_type=comm_type,
- arg_index=arg_index,
- key_for_kwarg=key_for_kwarg)
+ return CommAction(
+ comm_spec=self.get_communication_spec(
+ sharding_spec=sharding_spec,
+ communication_pattern=communication_pattern,
+ logical_process_axis=logical_process_axis,
+ ),
+ comm_type=comm_type,
+ arg_index=arg_index,
+ key_for_kwarg=key_for_kwarg,
+ )
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
@@ -155,9 +173,9 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
- comm_cost.fwd += num_ele_in_comm['forward']
- comm_cost.bwd += num_ele_in_comm['backward']
- comm_cost.total += num_ele_in_comm['total']
+ comm_cost.fwd += num_ele_in_comm["forward"]
+ comm_cost.bwd += num_ele_in_comm["backward"]
+ comm_cost.total += num_ele_in_comm["total"]
# check if communication action exists
# if so, loop over each action and compute the cost of each action
@@ -169,8 +187,8 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
- src_spec = comm_spec['src_spec']
- tgt_spec = comm_spec['tgt_spec']
+ src_spec = comm_spec["src_spec"]
+ tgt_spec = comm_spec["tgt_spec"]
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
@@ -187,14 +205,12 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the computation flops.
"""
- pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
- pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
@@ -212,13 +228,14 @@ def _compute_size_in_bytes_helper(sharding_spec, meta_data):
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
- dtype = getattr(meta_data, 'dtype')
+ dtype = getattr(meta_data, "dtype")
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
- assert isinstance(strategy.sharding_specs[op_data], list), \
- 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+ assert isinstance(
+ strategy.sharding_specs[op_data], list
+ ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
@@ -270,7 +287,6 @@ def validate(self) -> bool:
        Validate if the operands are of the desired shape.
        If True, this generator can be used for the current operation.
"""
- pass
class FollowingStrategyGenerator(StrategyGenerator):
@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_node: Node):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node
+ ):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node]):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
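
Both `_compute_and_add` and `_compute_size_in_bytes_helper` above reduce to the same arithmetic: bytes = number of sharded elements times the dtype's element size. A minimal sketch with an assumed per-device shape:

```python
import operator
from functools import reduce

import torch

# Assumed inputs, as the generators obtain them from
# ShardingSpec.get_sharded_shape_per_device() and the operand's meta data.
sharded_shape = (1024, 256)
dtype = torch.float16

num_elements = reduce(operator.mul, sharded_shape)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

print(num_elements * size_per_elem_bytes)  # 524288 bytes (0.5 MiB)
```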
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
index a0fbc58d70c0..dcbf34cfd65b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
@@ -4,22 +4,9 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-__all__ = ['SumGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SumGenerator"]
class SumGenerator(FollowingStrategyGenerator):
@@ -31,24 +18,24 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
- compute_cost = TrainCycleItem(fwd=input_size_product,
- bwd=output_size_product,
- total=input_size_product + output_size_product)
+ compute_cost = TrainCycleItem(
+ fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -66,8 +53,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -78,7 +66,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
+ sum_dims, sum_mapping_dict = self.op_data["sum_info"].data
        # TODO: a better way to handle the distributed sum is to sum all the data on each chip and then all-reduce
        # among all the shard groups
@@ -90,7 +78,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
- raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
+ raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
@@ -105,9 +93,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
index 93cfc9eeea53..eea00c2fa064 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
@@ -1,19 +1,10 @@
-import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
-__all__ = ['TensorConstructorGenerator']
+__all__ = ["TensorConstructorGenerator"]
class TensorConstructorGenerator(StrategyGenerator):
@@ -30,10 +21,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
        # fwd_cost = output
@@ -57,11 +48,13 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Tensor Constructor'
+ name = "Replica Tensor Constructor"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
index 39799a67c5a0..943cf3f1f50d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
@@ -5,7 +5,7 @@
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['UnaryElementwiseGenerator']
+__all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
@@ -21,12 +21,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -44,8 +44,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -69,9 +70,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
            # we keep the same strategies under different names for node merging; this will not increase the search space,
            # because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
index fa941f2cc51d..b27b4f3d4056 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['WhereGenerator']
+__all__ = ["WhereGenerator"]
class WhereGenerator(StrategyGenerator):
@@ -26,14 +26,14 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'condition': self._compute_size_in_bytes(strategy, "condition"),
- 'x': self._compute_size_in_bytes(strategy, "x"),
- 'y': self._compute_size_in_bytes(strategy, "y"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "condition": self._compute_size_in_bytes(strategy, "condition"),
+ "x": self._compute_size_in_bytes(strategy, "x"),
+ "y": self._compute_size_in_bytes(strategy, "y"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -59,7 +59,7 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
- "output": dim_partition
+ "output": dim_partition,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -67,9 +67,11 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -84,9 +86,9 @@ def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_l
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
        Generate every possible strategy for a where node, and record all strategies in the strategies_vector.
- '''
+ """
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
index 86f90694e060..5b4ea0afe5f8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
-__all__ = ['SumHandler']
+__all__ = ["SumHandler"]
@operator_registry.register(torch.Tensor.sum)
@@ -55,7 +55,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
- if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
+ if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
@@ -67,7 +67,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
- physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
+ physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -75,7 +75,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
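
To make the `sum_info` construction concrete: with keepdim=False the output drops the summed dims, so surviving input dims map to consecutive output indices. A sketch of the non-keepdim branch (the loop body is elided in the hunk above, so this reconstruction from the comments is an assumption):

```python
# Reconstructed mapping logic for keepdim=False: summed dims vanish from
# the output, so each surviving input dim gets the next output index.
num_dims = 4
sum_dims = (0, 2)

sum_mapping_dict = {}
output_index = 0
for i in range(num_dims):
    if i not in sum_dims:
        sum_mapping_dict[i] = output_index
        output_index += 1

print(sum_mapping_dict)  # {1: 0, 3: 1}: e.g. input dim 3 becomes output dim 1
```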
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
index 855a2e7612af..c2aa120e8a28 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
@@ -8,7 +8,7 @@
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
-__all__ = ['TensorConstructorHandler']
+__all__ = ["TensorConstructorHandler"]
@operator_registry.register(torch.arange)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
index 7a9d37726490..b72d9812f406 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
-__all__ = ['TransposeHandler']
+__all__ = ["TransposeHandler"]
@operator_registry.register(torch.Tensor.transpose)
@@ -48,9 +48,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
- physical_shape_operand = OperationData(name='transpose_dims',
- type=OperationDataType.ARG,
- data=list(transpose_dims))
+ physical_shape_operand = OperationData(
+ name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims)
+ )
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -58,7 +58,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
index 0362de780d7a..cbc873de8223 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
-__all__ = ['UnaryElementwiseHandler']
+__all__ = ["UnaryElementwiseHandler"]
@operator_registry.register(torch.Tensor.to)
@@ -33,9 +33,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
        # the strategies will be transformed back to their original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
index 7dff89d1d7a3..56c1d10a167e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
-__all__ = ['ViewHandler']
+__all__ = ["ViewHandler"]
@operator_registry.register(torch.Tensor.reshape)
@@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
- physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
+ physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -46,7 +46,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
index 6de2aaafdd01..1856a11100b0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
@@ -1,16 +1,15 @@
import copy
-import operator
from typing import Dict, List
import torch
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
-__all__ = ['WhereHandler']
+__all__ = ["WhereHandler"]
@operator_registry.register(torch.where)
@@ -28,27 +27,28 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
        # the strategies will be transformed back to their original shape in self.post_process
- physical_condition_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_x_operand = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.node.args[1]._meta_data)
- physical_y_operand = OperationData(name=str(self.node.args[2]),
- type=OperationDataType.ARG,
- data=self.node.args[2]._meta_data)
+ physical_condition_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_x_operand = OperationData(
+ name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data
+ )
+ physical_y_operand = OperationData(
+ name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
- "output": physical_output
+ "output": physical_output,
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
- logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
- logical_shape_for_all)
+ logical_mapping[key] = self.convert_physical_operand_to_logical_operand(
+ physical_operand, logical_shape_for_all
+ )
return logical_mapping, physical_mapping
@@ -64,7 +64,8 @@ def post_process(self, strategy: ShardingStrategy):
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- logical_sharding_spec, logical_shape, physical_shape)
+ logical_sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
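
The post-processing above undoes the broadcasting applied during strategy generation: all operands of `torch.where` are first treated as having the common logical shape, so a shard placed on a dim that a physical operand does not actually own has to be dropped when mapping back. A shape-level illustration (the dim-dropping rule is an assumption about what `recover_sharding_spec_for_broadcast_shape` does):

```python
# torch.where(condition, x, y) broadcasts all operands to a common shape.
logical_shape = (4, 8)     # shape used while generating strategies
physical_shape_y = (1, 8)  # y was broadcast along dim 0

logical_partition = {0: [0]}  # strategy shards logical dim 0 on mesh axis 0

# Mapping back: y has size 1 on dim 0, so the shard there must be removed
# (such removed mesh axes are what the handler receives as removed_dims).
physical_partition = {
    dim: axes for dim, axes in logical_partition.items() if physical_shape_y[dim] != 1
}
print(physical_partition)  # {}: y stays replicated on its broadcast dim
```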
diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py
index f0ea502a6f0e..e87872f39c10 100644
--- a/colossalai/auto_parallel/tensor_shard/options.py
+++ b/colossalai/auto_parallel/tensor_shard/options.py
@@ -1,13 +1,14 @@
from dataclasses import dataclass
from enum import Enum
-__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum):
"""
    This enum class defines the solver preference.
"""
+
STANDARD = 0
DP = 1
TP = 2
@@ -25,6 +26,7 @@ class ShardOption(Enum):
    TP_SHARD: We require the node to be sharded using tensor parallel strategies on the last device mesh axis.
    TP_FULL_SHARD: We require the node to be sharded using tensor parallel strategies on all device mesh axes.
"""
+
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
@@ -35,6 +37,7 @@ class DataloaderOption(Enum):
"""
    This enum class defines the dataloader option.
"""
+
REPLICATED = 0
DISTRIBUTED = 1
@@ -44,6 +47,7 @@ class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
+
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
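
Taken together, the three enums configure the strategy search. A typical construction (field names exactly as defined above, including the `solver_perference` spelling used in the source):

```python
from colossalai.auto_parallel.tensor_shard.options import (
    DataloaderOption,
    ShardOption,
    SolverOptions,
    SolverPerference,
)

# Prefer data-parallel strategies with a distributed dataloader.
options = SolverOptions(
    solver_perference=SolverPerference.DP,
    dataloader_option=DataloaderOption.DISTRIBUTED,
    shard_option=ShardOption.STANDARD,
)
```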
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index 6af927272437..8e22df64d868 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -10,7 +10,6 @@
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
- BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
@@ -18,13 +17,14 @@
RESHAPE_METHOD_OP,
)
-__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
+__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
+
INPUT = 0
ARG = 1
PARAM = 2
@@ -43,6 +43,7 @@ class OperationData:
data (Any): the value for this data, usually it is a meta tensor.
        logical_shape (Tuple[int]): the logical shape of the data; it can be different from its actual shape in memory.
"""
+
name: str
type: OperationDataType
data: Any
@@ -69,13 +70,13 @@ def _infer_logical_shape(data: any):
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
- return f'OperationData(name={self.name}, type={self.type})'
+ return f"OperationData(name={self.name}, type={self.type})"
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
- return hash(f'{self.name}')
+ return hash(f"{self.name}")
@dataclass
@@ -88,6 +89,7 @@ class TrainCycleItem:
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
+
fwd: Any
bwd: Any
total: Any
@@ -104,6 +106,7 @@ class MemoryCost:
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
+
activation: int = 0
parameter: int = 0
temp: int = 0
@@ -120,6 +123,7 @@ class CommType(Enum):
HOOK: the communication action is used to do the grad all reduce.
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
"""
+
BEFORE = 0
AFTER = 1
HOOK = 2
@@ -137,6 +141,7 @@ class CommAction:
        arg_index: records the location of the tensor that joins the communication; we cannot use the name of the node or op_data at runtime,
            because the args of the node may be changed by graph transform passes.
"""
+
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
@@ -156,6 +161,7 @@ class ShardingStrategy:
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
"""
+
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
@@ -200,7 +206,6 @@ def get_sharding_spec_by_name(self, name: str):
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
-
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
@@ -209,31 +214,34 @@ def _deepcopy_dict_vals(data: Dict):
# Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
        # In this case, if we assigned None to the new object, the program would crash when we try to access communication_actions.items.
- communication_actions = _deepcopy_dict_vals(
- self.communication_actions) if self.communication_actions is not None else None
+ communication_actions = (
+ _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
+ )
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
- return ShardingStrategy(name=self.name,
- sharding_specs=sharding_specs,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- communication_actions=communication_actions,
- resharding_costs=resharding_costs)
+ return ShardingStrategy(
+ name=self.name,
+ sharding_specs=sharding_specs,
+ compute_cost=compute_cost,
+ communication_cost=communication_cost,
+ memory_cost=memory_cost,
+ communication_actions=communication_actions,
+ resharding_costs=resharding_costs,
+ )
class StrategiesVector(list):
- '''
+ """
    Each node in the fx graph will have a corresponding StrategiesVector to store all the possible
    strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies are generated.
- '''
+ """
def __init__(self, node: Node):
super().__init__()
@@ -245,7 +253,7 @@ def __init__(self, node: Node):
def check_merge(self):
merge_label = False
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
@@ -255,7 +263,7 @@ def check_merge(self):
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
- if self.node.op == 'call_function':
+ if self.node.op == "call_function":
            # we could merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
@@ -267,7 +275,7 @@ def check_merge(self):
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
            # we could merge reshape ops, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
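
The merge check above keys entirely off the fx node's `op` field and its target. As a quick orientation for readers unfamiliar with torch.fx, a minimal sketch of the same classification logic over traced nodes (the model and op sets here are invented for illustration, not the library's constants):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# hypothetical stand-ins for ELEMENTWISE_MODULE_OP / ELEMENTWISE_FUNC_OP
ELEMENTWISE_MODULES = {nn.ReLU, nn.Dropout}
ELEMENTWISE_FUNCS = {torch.relu, torch.abs}

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()
    def forward(self, x):
        return torch.abs(self.act(x))

gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    mergeable = False
    if node.op == "call_module":
        # resolve the submodule behind the target string, as check_merge does
        mergeable = type(gm.get_submodule(node.target)) in ELEMENTWISE_MODULES
    elif node.op == "call_function":
        mergeable = node.target in ELEMENTWISE_FUNCS
    print(node.op, node.target, "-> mergeable:", mergeable)
```
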
diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
index f9e6bd923921..b930ce80a9b9 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
@@ -3,4 +3,4 @@
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
+__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
index 1b2d3ad57407..4415d429b0c2 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
@@ -4,7 +4,7 @@
class CostGraph:
- '''
+ """
A graph data structure to simplify the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
    CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
@@ -15,7 +15,7 @@ class CostGraph:
Argument:
        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node in the graph.
        simplify(bool, optional): The generated cost graph will be simplified if it is True. (defaults to True)
- '''
+ """
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
@@ -39,10 +39,10 @@ def _remove_invalid_node(self, node, attr_name):
target_node_list.remove(element)
def _build_cost_graph(self):
- '''
+ """
        This method will generate edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children' attributes will be
        set on the node.
- '''
+ """
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
@@ -84,8 +84,8 @@ def _check_tensor_in_node(data):
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
- setattr(dst_node, 'parents', parent_nodes)
- setattr(dst_node, 'children', children_nodes)
+ setattr(dst_node, "parents", parent_nodes)
+ setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -99,7 +99,7 @@ def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
- '''
+ """
        To merge dst_node into src_node, we need to do it in the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
@@ -119,7 +119,7 @@ def merge_node(self, src_node, dst_node):
Argument:
            src_node(Node): The node into which dst_node will be merged.
            dst_node(Node): The node to be merged into src_node.
- '''
+ """
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
@@ -196,7 +196,7 @@ def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
- for (src_node, dst_node) in self.merge_pair:
+ for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
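
To make the linearization described in the `CostGraph` docstring concrete: the 2D resharding-cost matrix between a source node with S strategies and a destination node with T strategies is flattened row-major into the 1D list the solver consumes. A toy sketch (hypothetical helper names and invented costs):

```python
def linearize(cost_matrix):
    # Flatten an S x T cost matrix row-major: entry (i, j) lands at i * T + j.
    num_dst = len(cost_matrix[0])
    flat = [cost_matrix[i][j] for i in range(len(cost_matrix)) for j in range(num_dst)]
    return flat, num_dst

def lookup(flat, num_dst, i, j):
    # recover the cost of pairing src strategy i with dst strategy j
    return flat[i * num_dst + j]

costs = [[0.0, 4.0], [2.0, 0.0], [1.5, 3.0]]  # 3 src strategies x 2 dst strategies
flat, num_dst = linearize(costs)
assert lookup(flat, num_dst, 2, 1) == 3.0
```
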
diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
index 171aa8b3399f..678965d663e4 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
@@ -7,7 +7,7 @@
from colossalai.fx.passes.utils import get_node_module
-__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
+__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
+
name: str
node: Node
is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the living variables at this current node.
"""
+
name: str
node: Node
all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
-
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@@ -105,18 +106,18 @@ def liveness_analysis(self) -> List[LiveStage]:
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
- if node.op == 'call_function':
+ if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
- if node.kwargs.get('inplace', False):
+ if node.kwargs.get("inplace", False):
is_inplace = True
- elif node.op == 'call_module':
+ elif node.op == "call_module":
            # to check if this is an inplace op such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
is_inplace = True
# add the output var
- meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
@@ -138,10 +139,12 @@ def liveness_analysis(self) -> List[LiveStage]:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
- stage = LiveStage(name=node.name,
- node=node,
- all_live_vars=all_live_variables.copy(),
- unique_live_vars=unique_live_vars.copy())
+ stage = LiveStage(
+ name=node.name,
+ node=node,
+ all_live_vars=all_live_variables.copy(),
+ unique_live_vars=unique_live_vars.copy(),
+ )
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
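
The liveness pass above walks the nodes in order, adding each produced value to the live set and retiring values after their last use. A toy version over a hand-written op list (a hypothetical mini-IR, not torch.fx) shows the bookkeeping:

```python
ops = [  # (produced value, consumed values), in topological order
    ("a", []),
    ("b", ["a"]),
    ("c", ["a", "b"]),
    ("d", ["c"]),
]

# record the index of the last use of every value
last_use = {}
for idx, (_, args) in enumerate(ops):
    for arg in args:
        last_use[arg] = idx

live = set()
for idx, (name, args) in enumerate(ops):
    live.add(name)
    print(f"after {name}: live = {sorted(live)}")
    # values whose last use is this node die here
    live -= {arg for arg in args if last_use[arg] == idx}
```
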
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index 564c5f09220c..088d1acb5177 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -21,24 +21,25 @@
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
- warnings.warn(f'please install the pulp')
+    warnings.warn("please install pulp")
-__all___ = ['Solver']
+__all__ = ["Solver"]
class Solver:
-
- def __init__(self,
- graph: Graph,
- strategies_constructor: StrategiesConstructor,
- cost_graph: CostGraph,
- graph_analyser: GraphAnalyser = None,
- memory_budget: float = -1.0,
- solution_numbers: int = 1,
- forward_only: bool = False,
- memory_increasing_coefficient: float = 1.3,
- verbose=False):
- '''
+ def __init__(
+ self,
+ graph: Graph,
+ strategies_constructor: StrategiesConstructor,
+ cost_graph: CostGraph,
+ graph_analyser: GraphAnalyser = None,
+ memory_budget: float = -1.0,
+ solution_numbers: int = 1,
+ forward_only: bool = False,
+ memory_increasing_coefficient: float = 1.3,
+ verbose=False,
+ ):
+ """
        The Solver class integrates the information provided by the components and uses an ILP solver to find a near-optimal strategy combination for the target computing graph.
Argument:
graph: The computing graph to be optimized.
@@ -48,7 +49,7 @@ def __init__(self,
memory_budget: Memory constraint for the solution.
            solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
- '''
+ """
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
@@ -75,11 +76,11 @@ def __init__(self,
self.verbose = verbose
def _recover_merged_node_strategy(self):
- '''
+ """
        During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
        Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy indices of those
        merged nodes.
- '''
+ """
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
@@ -98,9 +99,9 @@ def _generate_node_index_dict(self) -> Dict[Node, int]:
return node_index_dict
def _prepare_data_for_solver(self):
- '''
+ """
Extract information from components for solver.
- '''
+ """
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
@@ -190,23 +191,40 @@ def _prepare_data_for_solver(self):
# omit initial value for nodes
s_init_np = None
- return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
-
- def _call_solver_serialized_args(self,
- node_nums,
- memory_budget,
- strategies_len,
- following_nodes,
- edge_pairs,
- alias_set,
- liveness_set,
- compute_costs,
- communication_costs,
- memory_costs,
- resharding_costs,
- alias_convert_costs,
- s_init_np=None,
- verbose=True):
+ return (
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np,
+ self.verbose,
+ )
+
+ def _call_solver_serialized_args(
+ self,
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np=None,
+ verbose=True,
+ ):
"""
Call the solver with serialized arguments.
"""
@@ -235,18 +253,18 @@ def get_non_zero_index(binary_vector):
s_follow = following_nodes
s_alias = alias_set
- E = edge_pairs.reshape((-1, 2)) # noqa
+ E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
- for (i, j) in E:
+ for i, j in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
- r.append(resharding_costs[pt:pt + prod_length])
+ r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
@@ -268,7 +286,6 @@ def get_non_zero_index(binary_vector):
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
- v = []
pt = 0
c = []
@@ -277,9 +294,9 @@ def get_non_zero_index(binary_vector):
pt = 0
for i in range(node_nums):
length = strategies_len[i]
- c.append(compute_costs[pt:pt + length])
- d.append(communication_costs[pt:pt + length])
- m.append(memory_costs[pt:pt + length])
+ c.append(compute_costs[pt : pt + length])
+ d.append(communication_costs[pt : pt + length])
+ m.append(memory_costs[pt : pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
@@ -319,7 +336,7 @@ def get_non_zero_index(binary_vector):
e = []
num_edges = 0
map_edge_to_idx = {}
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
@@ -340,7 +357,7 @@ def get_non_zero_index(binary_vector):
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
- for (idx, value, fix) in s_init:
+ for idx, value, fix in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
@@ -393,7 +410,7 @@ def get_non_zero_index(binary_vector):
# (d). specified by `cat="Binary"`
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
@@ -402,13 +419,13 @@ def get_non_zero_index(binary_vector):
# (f)
for row in range(len(s[i])):
- C = len(s[j]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
- R = len(s[i]) # noqa
- C = len(s[j]) # noqa
+ R = len(s[i]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
@@ -434,7 +451,8 @@ def get_non_zero_index(binary_vector):
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
- onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
+ onlyAvailable=True
+ ), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
@@ -444,13 +462,13 @@ def get_non_zero_index(binary_vector):
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
- print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
- f"Time: {time.time() - tic}")
+ print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
- raise RuntimeError("Cannot run the function under the given memory budget. "
- "Please increase the memory budget.")
+ raise RuntimeError(
+ "Cannot run the function under the given memory budget. " "Please increase the memory budget."
+ )
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
@@ -458,7 +476,7 @@ def get_non_zero_index(binary_vector):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
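
For readers unfamiliar with pulp, here is a self-contained toy instance of the one-hot formulation the solver builds: two nodes with two candidate strategies each, a flattened pairwise resharding cost, and linking constraints in the spirit of the (f)/(g) constraints above. The costs are invented for illustration; running it requires `pip install pulp` with its bundled CBC solver.

```python
import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

# Toy instance: 2 nodes with 2 candidate strategies each, plus a flattened
# 2x2 resharding cost between every (strategy of node 0, strategy of node 1) pair.
node_costs = [[3.0, 1.0], [2.0, 2.5]]
edge_costs = [5.0, 0.0, 0.0, 5.0]  # flat index = i * 2 + j

prob = LpProblem("strategy_selection", LpMinimize)
s = [[LpVariable(f"s_{n}_{k}", cat="Binary") for k in range(2)] for n in range(2)]
e = [LpVariable(f"e_{k}", cat="Binary") for k in range(4)]

for n in range(2):      # exactly one strategy per node
    prob += lpSum(s[n]) == 1
prob += lpSum(e) == 1   # exactly one edge assignment
for i in range(2):      # link edge variables to node variables
    prob += lpSum(e[i * 2 + j] for j in range(2)) <= s[0][i]
for j in range(2):
    prob += lpSum(e[i * 2 + j] for i in range(2)) <= s[1][j]

# objective: per-node strategy costs plus the pairwise resharding cost
prob += lpDot(s[0], node_costs[0]) + lpDot(s[1], node_costs[1]) + lpDot(e, edge_costs)
prob.solve(pulp.PULP_CBC_CMD(msg=False))
print("node 0 picks", [k for k in range(2) if s[0][k].value() == 1])
print("node 1 picks", [k for k in range(2) if s[1][k].value() == 1])
```
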
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 044a8ac847ea..aa87ee9bf3db 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -1,11 +1,5 @@
-import builtins
-import math
-import operator
-from copy import deepcopy
-from typing import Dict, List
-
import torch
-from torch.fx import Graph, Node
+from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
@@ -14,13 +8,12 @@
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
-__all__ = ['StrategiesConstructor']
+__all__ = ["StrategiesConstructor"]
class StrategiesConstructor:
@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+        assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -46,11 +39,11 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
- '''
+ """
        In the build_strategies_and_cost method, we may produce some duplicated strategies.
        In this method, we remove the duplicated strategies based on the strategy name.
        Note that this operation is in-place.
- '''
+ """
name_checklist = []
remove_list = []
for strategy in strategies_vector:
@@ -62,7 +55,6 @@ def remove_duplicated_strategy(self, strategies_vector):
strategies_vector.remove(strategy)
def generate_alias_set(self):
-
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
@@ -83,7 +75,7 @@ def build_strategies_and_cost(self):
"""
def _check_no_strategy_for_node(node):
- if node.op in ('placeholder', 'get_attr', 'output'):
+ if node.op in ("placeholder", "get_attr", "output"):
return False
def _check_no_strategy_for_data(data):
@@ -102,83 +94,93 @@ def _check_no_strategy_for_data(data):
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
- pass
# placeholder node
- elif node.op == 'placeholder':
+ elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- placeholder_option = 'distributed'
+ placeholder_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- placeholder_option = 'replicated'
- placeholder_handler = PlaceholderHandler(node,
- self.device_mesh,
- strategies_vector,
- placeholder_option=placeholder_option)
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+                    ), f"dataloader_option {self.solver_options.dataloader_option} is not supported"
+ placeholder_option = "replicated"
+ placeholder_handler = PlaceholderHandler(
+ node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
+ )
placeholder_handler.register_strategy()
# get_attr node
- elif node.op == 'get_attr':
- getattr_handler = GetattrHandler(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ elif node.op == "get_attr":
+ getattr_handler = GetattrHandler(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
getattr_handler.register_strategy()
# call_module node
- elif node.op == 'call_module':
+ elif node.op == "call_module":
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
- handler = operator_registry.get(submod_type)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(submod_type)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_function node
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
- handler = operator_registry.get(target)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(target)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_method node
- elif node.op == 'call_method':
+ elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target)
- handler = operator_registry.get(method)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(method)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# output node
- elif node.op == 'output':
+ elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- output_option = 'distributed'
+ output_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- output_option = 'replicated'
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+                    ), f"dataloader_option {self.solver_options.dataloader_option} is not supported"
+ output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
- setattr(node, 'strategies_vector', strategies_vector)
+ setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
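
The repeated `operator_registry.get(...)` pattern above is a target-keyed handler lookup. A minimal sketch of that dispatch style (hypothetical `Registry` class and handler, not the colossalai `operator_registry` API):

```python
class Registry:
    def __init__(self):
        self._store = {}

    def register(self, target):
        # decorator that maps a call target to its handler class
        def wrapper(cls):
            self._store[target] = cls
            return cls
        return wrapper

    def get(self, target):
        return self._store[target]

registry = Registry()

@registry.register("linear")
class LinearHandler:
    def register_strategy(self):
        return ["row-sharded", "column-sharded", "replicated"]

handler = registry.get("linear")()   # look up by target, then instantiate
print(handler.register_strategy())
```
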
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index b7fe5430bf13..d61cfd2add15 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -17,9 +17,21 @@
)
__all__ = [
- 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
- 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
- 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
+ "BroadcastType",
+ "get_broadcast_shape",
+ "is_broadcastable",
+ "recover_sharding_spec_for_broadcast_shape",
+ "generate_resharding_costs",
+ "generate_sharding_spec",
+ "ignore_sharding_exception",
+    "check_sharding_spec_validity",
+    "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
+ "comm_actions_for_oprands",
+ "pytree_map",
+ "detect_reshape_mapping",
+ "check_keep_sharding_status",
+ "infer_output_dim_partition_dict",
]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
index 307348ea1eaf..99d5a0f2a942 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
@@ -14,8 +14,11 @@
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
- 'comm_actions_for_oprands'
+ "BroadcastType",
+ "is_broadcastable",
+ "get_broadcast_shape",
+ "recover_sharding_spec_for_broadcast_shape",
+ "comm_actions_for_oprands",
]
@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
- assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
+ assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
- assert logical_num_dims >= physical_num_dims, \
- 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
+ assert (
+ logical_num_dims >= physical_num_dims
+    ), "The number of dimensions of the logical shape is smaller than that of the physical shape, so this tensor is not a broadcast tensor!"
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
@@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
return logical_dim_broadcast_info
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
- physical_shape: torch.Size) -> ShardingSpec:
+def recover_sharding_spec_for_broadcast_shape(
+ logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
+) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
@@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
- physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=physical_dim_partition)
+ physical_sharding_spec = ShardingSpec(
+ device_mesh=logical_sharding_spec.device_mesh,
+ entire_shape=physical_shape,
+ dim_partition_dict=physical_dim_partition,
+ )
return physical_sharding_spec, removed_dims
-def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
- sharding_spec: ShardingSpec) -> CommAction:
+def comm_actions_for_oprands(
+ node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
+) -> CommAction:
"""
    This method is used to generate communication actions for operands that lose information
    during the conversion from the logical shape to the physical shape.
@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
- comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
- sharding_spec=sharding_spec,
- logical_process_axis=removed_dims)
+ comm_spec = CommSpec(
+ comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+ sharding_spec=sharding_spec,
+ logical_process_axis=removed_dims,
+ )
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
- assert arg_index >= 0, f'op_data should be an argument of node.'
+    assert arg_index >= 0, "op_data should be an argument of node."
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
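
The broadcast helpers in this file follow the right-aligned broadcasting rule of NumPy/PyTorch. A plain-Python sketch of the two shape utilities, using lists of ints instead of `torch.Size` and simplified from the real implementations:

```python
def is_broadcastable(shape1, shape2):
    # compare trailing dims; each pair must match or contain a 1
    for d1, d2 in zip(reversed(shape1), reversed(shape2)):
        if d1 != d2 and d1 != 1 and d2 != 1:
            return False
    return True

def get_broadcast_shape(shape1, shape2):
    assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
    out = []
    for i in range(max(len(shape1), len(shape2))):
        d1 = shape1[-1 - i] if i < len(shape1) else 1   # pad missing dims with 1
        d2 = shape2[-1 - i] if i < len(shape2) else 1
        out.append(max(d1, d2))
    return out[::-1]

assert get_broadcast_shape([4, 1, 3], [2, 3]) == [4, 2, 3]
```
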
diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py
index 347c10aa102d..aaca923a5eee 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/factory.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py
@@ -14,11 +14,12 @@
from ..constants import INFINITY_COST
-__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
+__all__ = ["generate_sharding_spec", "generate_resharding_costs"]
-def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
- dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+def generate_sharding_spec(
+ input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]]
+) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
- assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
+        assert hasattr(input_, "_meta_data"), "The given node has no attribute _meta_data"
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
shape = input_.shape
else:
raise TypeError(
- f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
+ f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected."
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
+ assert (
+ shape[dim_index] % sharding_size == 0
+        ), f"we cannot shard dimension {dim_index} of the tensor into {sharding_size} partitions."
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
-def generate_resharding_costs(nodes: List[Node],
- sharding_specs: List[ShardingSpec],
- count_backward: Optional[bool] = True,
- dtype: Optional[torch.dtype] = None,
- index=None):
- '''
+def generate_resharding_costs(
+ nodes: List[Node],
+ sharding_specs: List[ShardingSpec],
+ count_backward: Optional[bool] = True,
+ dtype: Optional[torch.dtype] = None,
+ index=None,
+):
+ """
Compute the resharding costs with this specific strategy.
Argument:
@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node],
        sharding_specs (List[ShardingSpec]): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
- '''
+ """
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node],
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
- assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+ assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected."
input_sharding_spec = input_sharding_spec[index]
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+                assert isinstance(input_sharding_spec, ShardingSpec), "The input node should NOT be a tuple of tensors."
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
+ input_sharding_spec, input_spec
+ )
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
- warnings.warn(f'{e}')
+ warnings.warn(f"{e}")
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
- '''
+ """
    Find the largest repeated blocks in the graph whose length exceeds the threshold.
    Args:
        node_list (List[torch.fx.Node]): the nodes to be analyzed, owned by root_module.
        common_length_threshold (int): the minimum length of a repeated block.
- '''
+ """
# graph = gm.graph
def _process_args(args):
new_args = []
for arg in args:
- if hasattr(arg, '_meta_data'):
+ if hasattr(arg, "_meta_data"):
meta_data = arg._meta_data
else:
meta_data = arg
@@ -145,7 +150,7 @@ def _check_node_equal(node1, node2):
return False
for index, node in enumerate(node_list):
- if node.op == 'call_module':
+ if node.op == "call_module":
target = node.target
submod = root_module.get_submodule(target)
submod_type = type(submod)
@@ -155,12 +160,12 @@ def _check_node_equal(node1, node2):
new_args = _process_args(node.args)
- if node.op != 'get_attr':
+ if node.op != "get_attr":
hash_key = (node.op, target, *new_args)
else:
hash_key = (node.op,)
- setattr(node, 'hash_key', hash_key)
+ setattr(node, "hash_key", hash_key)
hash_value_to_node_dict = {}
@@ -179,7 +184,7 @@ def _check_node_equal(node1, node2):
# the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
- check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
+ check_block_list = [node_list[start : start + max_common_length] for start in start_index_list]
common_label = True
if not _all_equal(check_block_list, _check_node_list_equal):
@@ -201,6 +206,6 @@ def _check_node_equal(node1, node2):
# recover common subgraph from the index
common_blocks = []
for start in common_blocks_index:
- common_blocks.append(node_list[start:start + max_common_length])
+ common_blocks.append(node_list[start : start + max_common_length])
return common_blocks
diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py
index 475e95fc4326..42ec2a8ee428 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/misc.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py
@@ -1,12 +1,12 @@
import functools
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
-__all__ = ['ignore_sharding_exception', 'pytree_map']
+__all__ = ["ignore_sharding_exception", "pytree_map"]
def ignore_sharding_exception(func):
@@ -48,29 +48,32 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.shape[0]
num_devices_in_row = sharding_spec.device_mesh.shape[1]
- assert sharding_len == tensor_num_dim, \
- f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+ assert (
+ sharding_len == tensor_num_dim
+ ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
- if str(dim_spec).startswith('S'):
- devices_str = str(dim_spec).lstrip('S')
+ if str(dim_spec).startswith("S"):
+ devices_str = str(dim_spec).lstrip("S")
num_devices = 1
- if '0' in devices_str:
+ if "0" in devices_str:
num_devices *= num_devices_in_col
- if '1' in devices_str:
+ if "1" in devices_str:
num_devices *= num_devices_in_row
- assert dim_size >= num_devices and dim_size % num_devices == 0, \
- f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+ assert (
+ dim_size >= num_devices and dim_size % num_devices == 0
+ ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
# make sure the entire shape matches the physical tensor shape
- assert sharding_spec.entire_shape == tensor.shape, \
- f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+ assert (
+ sharding_spec.entire_shape == tensor.shape
+ ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
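
`pytree_map` from the `__all__` above recursively applies a function to the leaves of nested containers. A simplified sketch of the idea (an assumption about the behavior, not a copy of the real utility):

```python
def pytree_map(obj, fn, process_types=(), map_all=False):
    # recurse through dicts, lists and tuples; apply fn to matching leaves
    if isinstance(obj, dict):
        return {k: pytree_map(v, fn, process_types, map_all) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        mapped = [pytree_map(v, fn, process_types, map_all) for v in obj]
        return type(obj)(mapped)
    if map_all or isinstance(obj, process_types):
        return fn(obj)
    return obj

nested = {"a": [1, 2, (3, 4)], "b": 5}
assert pytree_map(nested, lambda x: x * 10, process_types=(int,)) == {
    "a": [10, 20, (30, 40)],
    "b": 50,
}
```
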
diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
index d0ebbd7e8b1b..329312ef797f 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
@@ -8,6 +8,7 @@ class PreviousStatus(Enum):
"""
This class shows the status of previous comparison.
"""
+
RESET = 0
# ORIGIN means the dimension size of original tensor is larger in the previous comparison.
ORIGIN = 1
@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
return reshape_mapping_dict
-def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
+def check_keep_sharding_status(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> bool:
"""
    This method is used to check whether the reshape operation can be implemented without converting
    the input to a fully replicated status.
@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
return True
-def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
+def infer_output_dim_partition_dict(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
- assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
- 'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
+ assert check_keep_sharding_status(
+ input_dim_partition_dict, reshape_mapping_dict
+    ), "we only infer the output dim partition dict for reshape operations that can keep the sharding spec."
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
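
The reshape utilities reason about which input dimensions map to which output dimensions. A toy illustration of such a mapping (a hypothetical `reshape_dim_mapping`, simpler than `detect_reshape_mapping` and ignoring size-1 edge cases): dims are grouped greedily, left to right, until their size products match.

```python
def reshape_dim_mapping(origin_shape, tgt_shape):
    mapping, i, j = {}, 0, 0
    while i < len(origin_shape) and j < len(tgt_shape):
        in_dims, out_dims = [i], [j]
        p, q = origin_shape[i], tgt_shape[j]
        while p != q:
            # grow whichever side has the smaller running product
            if p < q:
                i += 1; in_dims.append(i); p *= origin_shape[i]
            else:
                j += 1; out_dims.append(j); q *= tgt_shape[j]
        mapping[tuple(in_dims)] = tuple(out_dims)
        i += 1; j += 1
    return mapping

# (4, 6) -> (2, 2, 6): dim 0 splits into dims (0, 1); dim 1 maps to dim 2.
assert reshape_dim_mapping((4, 6), (2, 2, 6)) == {(0,): (0, 1), (1,): (2,)}
```
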
diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
index e2ce59e0b577..b5386d599be4 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
@@ -8,8 +8,11 @@
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+ "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
]
@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
- assert len(sharding_spec.entire_shape) >= 2, \
- 'The entire_shape of the sharding spec must have at least 2 dimensions'
+ assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
return sharding_spec
-def update_partition_dim(sharding_spec: ShardingSpec,
- dim_mapping: Dict[int, int],
- physical_shape: torch.Size,
- inplace: bool = False):
+def update_partition_dim(
+ sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
+):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
- current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=new_dim_partition_dict)
+ current_sharding_spec.__init__(
+ device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
+ )
return current_sharding_spec
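
`transpose_partition_dim` boils down to swapping the mesh axes assigned to two tensor dimensions in the dim partition dict. A standalone sketch with plain dicts instead of `ShardingSpec` (hypothetical helper and an invented example):

```python
def transpose_partition_dim(dim_partition, dim1, dim2):
    # swap whatever mesh axes are assigned to dim1 and dim2, if any
    out = dict(dim_partition)
    d1, d2 = out.pop(dim1, None), out.pop(dim2, None)
    if d2 is not None:
        out[dim1] = d2
    if d1 is not None:
        out[dim2] = d1
    return out

# mesh axis 0 shards tensor dim 0, mesh axis 1 shards tensor dim 2;
# transposing tensor dims 0 and 2 swaps those assignments
assert transpose_partition_dim({0: [0], 2: [1]}, 0, 2) == {0: [1], 2: [0]}
```
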
diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index cc98c1570b4a..9571fa2c17f0 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -9,7 +9,18 @@
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
- from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+ from torch.fx.graph import (
+ CodeGen,
+ PythonCode,
+ _custom_builtins,
+ _CustomBuiltin,
+ _format_target,
+ _is_from_torch,
+ _Namespace,
+ _origin_type_map,
+ inplace_methods,
+ magic_methods,
+ )
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
- tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
- input_node.name)
- tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+ tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
+ tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
- context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
- input_node.name, input_node.name)
+ context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
+ chunk_output[i].name,
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]]
@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
return context
-def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
- chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+def _gen_loop_end(
+ chunk_inputs: List[Node],
+ chunk_non_compute_inputs: List[Node],
+ node_list: List[Node],
+ chunk_outputs_idx: int,
+ chunk_outputs_non_tensor: List[Node],
+ search_chunk: SearchChunk,
+) -> str:
"""
Generate chunk loop end
@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
- if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
- or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+ if (
+ source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+ or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
+ ):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
- chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
- get_node_shape(chunk_node))
+ chunk_slice = _gen_chunk_slice_dim(
+ chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
+ )
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
- for i in range(len(chunk_node.meta['tensor_meta'])):
+ for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
@@ -216,13 +243,15 @@ def _add_node_slice(
return body
-def emit_code_with_chunk(body: List[str],
- nodes: Iterable[Node],
- emit_node_func: Callable,
- delete_unused_value_func: Callable,
- search_chunk: SearchChunk,
- chunk_infos: List,
- eval_mem: bool = False):
+def emit_code_with_chunk(
+ body: List[str],
+ nodes: Iterable[Node],
+ emit_node_func: Callable,
+ delete_unused_value_func: Callable,
+ search_chunk: SearchChunk,
+ chunk_infos: List,
+ eval_mem: bool = False,
+):
"""
Emit code with chunk according to chunk_infos.
@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
- chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
- chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
- chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
+ chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
+ chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
+ chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
- ))
+ )
+ )
if within_chunk_region:
emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
# generate chunk region end
if node_idx in chunk_ends:
body.append(
- _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
- chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+ _gen_loop_end(
+ chunk_inputs[region_idx],
+ chunk_inputs_non_chunk[region_idx],
+ node_list,
+ chunk_ends[region_idx],
+ chunk_outputs_non_tensor[region_idx],
+ search_chunk,
+ )
+ )
within_chunk_region = False
node_idx += 1
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
-
- def __init__(self,
- meta_graph,
- max_memory: int = None,
- print_mem: bool = False,
- print_progress: bool = False,
- eval_mem: bool = False) -> None:
+ def __init__(
+ self,
+ meta_graph,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ eval_mem: bool = False,
+ ) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
@@ -349,7 +389,7 @@ def add_global(name_hint: str, obj: Any):
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -402,7 +442,6 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ def delete_unused_values(user: Node, body, to_keep=[]):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
@@ -470,42 +509,56 @@ def emit_node(node: Node, body):
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
- f"({_format_args(node.args[1:], node.kwargs)})")
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
- body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
- f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
- and node.args[1].isidentifier() and len(node.args) == 2):
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
@@ -523,8 +576,9 @@ def emit_node(node: Node, body):
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
- emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
- self.eval_mem)
+ emit_code_with_chunk(
+ body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
+ )
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
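
To see what the emitted chunked code amounts to, here is a hand-written before/after pair in ordinary PyTorch (illustrative only, not output of the codegen): the result buffer is allocated once with `torch.empty`, then filled chunk by chunk inside the loop, which caps peak activation memory at roughly one chunk's worth.

```python
import torch

def dense_softmax(x):
    # original form: materializes the full intermediate at once
    return torch.softmax(x, dim=-1)

def chunked_softmax(x, chunk_size=2):
    # generated-style form: preallocate the output, compute per chunk
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    for chunk_idx in range(0, x.shape[0], chunk_size):
        out[chunk_idx:chunk_idx + chunk_size] = torch.softmax(
            x[chunk_idx:chunk_idx + chunk_size], dim=-1)
    return out

x = torch.randn(6, 8)
assert torch.allclose(dense_softmax(x), chunked_softmax(x))
```

Chunking along dim 0 is exact here because softmax over the last dim is row-independent; ops that mix rows need the dependency tracing done by `SearchChunk`.
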
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 77bc2ef17bc3..a85ad429e261 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -1,11 +1,8 @@
-import copy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Dict, List
import torch
from torch.fx.node import Node
-from colossalai.fx.profiler import activation_size, parameter_size
-
from .utils import NodeMgr, get_node_shape, is_non_memory_node
@@ -62,12 +59,9 @@ def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
- def _remove_deactive_node(self,
- user_idx: int,
- user: Node,
- active_nodes: List,
- delete_node_dict: List,
- kept_nodes: List = None) -> None:
+ def _remove_deactive_node(
+ self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
+ ) -> None:
"""
remove deactivate nodes from active nodes
"""
@@ -169,7 +163,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
- chunk_ratio = 1 # use it to estimate chunk mem
+ chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
@@ -184,7 +178,6 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
-
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
@@ -193,8 +186,9 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
# determine chunk ratio for current node
if chunk_within:
- chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
- chunk_sizes[chunk_region_idx])
+ chunk_ratio = self._get_chunk_ratio(
+ node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
+ )
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
@@ -222,7 +216,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
- self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
+            self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)  # don't provide kept nodes for now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 59645c80e808..1c599049d9eb 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,7 +8,7 @@
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
-from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -121,8 +121,10 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
- if i["region"][0] < peak_region[0] <= i["region"][1] or \
- i["region"][0] < peak_region[1] <= i["region"][1]:
+ if (
+ i["region"][0] < peak_region[0] <= i["region"][1]
+ or i["region"][0] < peak_region[1] <= i["region"][1]
+ ):
return None
active_node_num = [len(i) for i in active_node]
@@ -146,9 +148,9 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
- elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+ elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
- elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+ elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@@ -171,7 +173,7 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
- if len(start_traces) > 1: # TODO need to be removed
+ if len(start_traces) > 1:  # TODO: needs to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
@@ -180,8 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
- if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
- end_idx):
+ if not self.trace_flow.check_region_start_end(
+ start_node, start_dim, start_idx, end_node, end_dim, end_idx
+ ):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
@@ -203,7 +206,7 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
- input_trace = [] # trace of a node's input nodes
+ input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
@@ -215,7 +218,8 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
- self.node_mgr.get_node_by_idx(end_idx)):
+ self.node_mgr.get_node_by_idx(end_idx)
+ ):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
@@ -279,15 +283,18 @@ def search_region(self) -> Dict:
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
- self.node_mgr.get_node_list(), chunk_infos)
+ self.node_mgr.get_node_list(), chunk_infos
+ )
if self.print_progress:
- get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
- (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+ get_logger().info(
+ "AutoChunk find chunk region %d = (%d, %d)"
+ % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
+ )
if self.print_mem:
self.print_mem = False
- self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
- chunk_infos,
- print_mem=True)
+ self.estimate_memory.estimate_chunk_inference_mem(
+ self.node_mgr.get_node_list(), chunk_infos, print_mem=True
+ )
return chunk_infos
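
The condition reflowed in `_search_max_chunk_region` is a plain interval-overlap test: a candidate peak region is rejected when either of its endpoints already falls inside a previously chosen chunk region. Extracted as a self-contained sketch (the function name is illustrative):

```python
from typing import Tuple

def peak_in_existing_region(peak: Tuple[int, int], region: Tuple[int, int]) -> bool:
    # Either endpoint of `peak` landing in (region[0], region[1]] means the
    # candidate collides with an already-chosen chunk region.
    return region[0] < peak[0] <= region[1] or region[0] < peak[1] <= region[1]

assert peak_in_existing_region((5, 7), (3, 9))        # inside -> reject
assert not peak_in_existing_region((10, 12), (3, 9))  # disjoint -> keep searching
```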
diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py
index 94a29bfd5691..8a60ba681f70 100644
--- a/colossalai/autochunk/select_chunk.py
+++ b/colossalai/autochunk/select_chunk.py
@@ -5,7 +5,6 @@
class SelectChunk(object):
-
def __init__(
self,
trace_indice: TraceIndice,
@@ -20,7 +19,7 @@ def __init__(
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
- self.max_memory = max_memory # MB
+ self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
@@ -57,16 +56,18 @@ def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, m
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
+ cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
- regions_dict.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -90,13 +91,15 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
- chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
- chunk_infos)
+ chunk_info["chunk_size"] = self._chunk_size_binary_search(
+ chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
+ )
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -109,9 +112,10 @@ def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos)
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
@@ -139,8 +143,10 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
return None
# get max possible chunk region
- max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
- max([i["region"][1] for i in possible_chunk_regions]))
+ max_possible_chunk_region = (
+ min([i["region"][0] for i in possible_chunk_regions]),
+ max([i["region"][1] for i in possible_chunk_regions]),
+ )
# get mem for chunk region
regions_dict_list = []
@@ -149,15 +155,17 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+ cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
- regions_dict_list.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict_list.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
@@ -175,7 +183,9 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos):
return False
for i in chunk_infos:
region = i["region"]
- if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
- (chunk_region_start < region[0] and chunk_region_end < region[0])):
+ if not (
+ (chunk_region_start > region[1] and chunk_region_end > region[1])
+ or (chunk_region_start < region[0] and chunk_region_end < region[0])
+ ):
return False
return True
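
`_select_fit_memory_chunk_region` first doubles the chunk size until memory fits, then `_chunk_size_binary_search` narrows the result down. A simplified sketch of the binary-search half, with a toy `mem_at` cost function standing in for the real memory estimator (the real code also steps by a small `gap`; that detail is omitted here):

```python
def largest_feasible_chunk(left: int, right: int, mem_at, max_memory: float) -> int:
    # Binary-search the biggest chunk size whose peak memory stays below the cap,
    # mirroring the structure of _chunk_size_binary_search.
    best = left
    while left <= right:
        mid = (left + right + 1) // 2
        if mem_at(mid) >= max_memory:
            right = mid - 1
        else:
            best = mid
            left = mid + 1
    return best

# Toy cost model: memory grows linearly with chunk size.
print(largest_feasible_chunk(1, 64, lambda c: 0.5 * c, 16.0))  # 31
```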
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index a1080fda1541..8b36c99bbadd 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -16,7 +16,6 @@
class TraceFlow(object):
-
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
@@ -151,7 +150,7 @@ def _assign_single_node_flow(
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
- cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
+ cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -266,7 +265,7 @@ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int,
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
- ) # from last node to first node
+ ) # from last node to first node
prepose_nodes = []
# set each node as a root and search its args; if all are legal, mark the root and its args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -328,7 +327,8 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
- self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+ self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+ )
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
@@ -371,8 +371,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
return chunk_info
- def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
- chunk_info: Dict):
+ def _get_other_output_info(
+ self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
+ ):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
@@ -384,8 +385,8 @@ def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim:
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
- if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
- chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+ if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
+ chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
@@ -421,7 +422,8 @@ def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output:
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
- set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+ set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
+ )
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
@@ -443,8 +445,11 @@ def _reassign_reshape_size(self, chunk_info):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
- if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
- reshape_args[0].meta['fwd_out']) > 1:
+ if (
+ len(reshape_args) == 1
+ and get_node_shape(reshape_args[0]) is None
+ and len(reshape_args[0].meta["fwd_out"]) > 1
+ ):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@@ -462,16 +467,17 @@ def _reassign_reshape_size(self, chunk_info):
chunk_info["reshape_size"] = reshape_size
return chunk_info
- def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
- end_idx: int) -> bool:
+ def check_region_start_end(
+ self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
+ ) -> bool:
"""
check whether the region start and end are legal
"""
# dim cannot be None
- if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+ if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
- if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+ if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:
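
`check_region_start_end` prunes candidates with three cheap tests that survive the reflow unchanged: both boundary nodes need a known shape, neither chosen dim may have size 1, and the end node must have users. A plain-Python paraphrase, with shapes passed in directly instead of read off `torch.fx` nodes (names are illustrative):

```python
from typing import Optional, Sequence

def region_boundaries_legal(
    start_shape: Optional[Sequence[int]], start_dim: int,
    end_shape: Optional[Sequence[int]], end_dim: int,
    end_num_users: int,
) -> bool:
    if start_shape is None or end_shape is None:  # dims cannot be None
        return False
    if start_shape[start_dim] == 1 or end_shape[end_dim] == 1:  # size-1 dims can't be chunked
        return False
    return end_num_users > 0  # the end node must have users

print(region_boundaries_legal([8, 128], 1, [8, 128], 1, end_num_users=2))  # True
```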
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index fbe0741b8827..378c54acf782 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -1,5 +1,5 @@
import copy
-from typing import Dict, List, Tuple
+from typing import Dict, List
from torch.fx.node import Node
@@ -412,7 +412,7 @@ def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
node_idx (int)
"""
# get conv input
- assert node.kwargs['size'] is None
+ assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
# assign index
@@ -826,7 +826,7 @@ def _clear_trace(self, node_idx: int) -> None:
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
- if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+ if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
@@ -876,10 +876,24 @@ def trace_indice(self) -> None:
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
- elif any(n == node_name for n in [
- "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
- "sin", "cos"
- ]):
+ elif any(
+ n == node_name
+ for n in [
+ "mul",
+ "add",
+ "sigmoid",
+ "relu",
+ "sub",
+ "truediv",
+ "pow",
+ "dropout",
+ "where",
+ "tanh",
+ "exp",
+ "sin",
+ "cos",
+ ]
+ ):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
@@ -920,7 +934,7 @@ def trace_indice(self) -> None:
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
- self._assign_all_indice(node, idx) # get param
+ self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 064baa047155..f6f803a5ce0a 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, List, Union
from torch.fx.node import Node
@@ -10,7 +10,6 @@
class NodeMgr(object):
-
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
- if (input_node not in nodes and input_node not in input_nodes
- and not is_non_compute_node_except_placeholder(input_node)):
+ if (
+ input_node not in nodes
+ and input_node not in input_nodes
+ and not is_non_compute_node_except_placeholder(input_node)
+ ):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
- if (output_node not in nodes and node not in output_nodes
- and not is_non_compute_node_except_placeholder_output(output_node)):
+ if (
+ output_node not in nodes
+ and node not in output_nodes
+ and not is_non_compute_node_except_placeholder_output(output_node)
+ ):
output_nodes.append(node)
return input_nodes, output_nodes
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
- elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
- node.meta['fwd_out'][0], int):
+ elif (
+ len(node.meta["fwd_out"]) > 0
+ and isinstance(node.meta["fwd_out"], list)
+ and isinstance(node.meta["fwd_out"][0], int)
+ ):
out.append(node)
return out
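
`find_chunk_compute_input_and_output_nodes` classifies a node span by looking across its boundary: outside producers that feed the span become region inputs, and span nodes whose values are read outside become region outputs. A toy version over plain adjacency dicts (hypothetical names, no `torch.fx` dependency):

```python
def region_inputs_outputs(region, producers, users):
    # producers: node -> nodes it reads; users: node -> nodes that read it.
    inputs, outputs = [], []
    for node in region:
        for p in producers.get(node, ()):
            if p not in region and p not in inputs:
                inputs.append(p)          # outside producer feeding the region
        if any(u not in region for u in users.get(node, ())) and node not in outputs:
            outputs.append(node)          # region node whose value escapes
    return inputs, outputs

producers = {"b": ["a"], "c": ["b"], "d": ["c"]}
users = {"a": ["b"], "b": ["c"], "c": ["d"]}
print(region_inputs_outputs(["b", "c"], producers, users))  # (['a'], ['c'])
```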
diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py
index fc2c4a40068b..92990907bc2e 100644
--- a/colossalai/booster/accelerator.py
+++ b/colossalai/booster/accelerator.py
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
-__all__ = ['Accelerator']
+__all__ = ["Accelerator"]
_supported_devices = [
- 'cpu',
- 'cuda',
-
+ "cpu",
+ "cuda",
# To be supported
# 'xpu',
# 'npu',
@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str):
self.device = device
- assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+ assert (
+ self.device in _supported_devices
+ ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self):
"""
Set the default device for the current process.
"""
- if self.device == 'cpu':
+ if self.device == "cpu":
pass
- elif self.device == 'cuda':
+ elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed:
# local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
- torch.cuda.set_device(torch.device('cuda'))
- pass
+ torch.cuda.set_device(torch.device("cuda"))
else:
raise ValueError(f"Device {self.device} is not supported yet")
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index fb9dae7c9650..d73bc5babd80 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -8,6 +8,7 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
+import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -16,7 +17,7 @@
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase
-__all__ = ['Booster']
+__all__ = ["Booster"]
class Booster:
@@ -60,28 +61,31 @@ class Booster:
plugin (Plugin): The plugin to run the training. Default: None.
"""
- def __init__(self,
- device: Optional[str] = None,
- mixed_precision: Optional[Union[MixedPrecision, str]] = None,
- plugin: Optional[Plugin] = None) -> None:
+ def __init__(
+ self,
+ device: Optional[str] = None,
+ mixed_precision: Optional[Union[MixedPrecision, str]] = None,
+ plugin: Optional[Plugin] = None,
+ ) -> None:
if plugin is not None:
assert isinstance(
- plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
+ plugin, Plugin
+ ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
if device is not None:
- warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else:
- device = device or 'cuda'
+ device = device or "cuda"
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
if mixed_precision is not None:
- warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -95,7 +99,7 @@ def __init__(self,
self.mixed_precision = mixed_precision
else:
raise ValueError(
- f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
+ f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
@@ -128,20 +132,28 @@ def boost(
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
+ pretrained_path = pretrained_utils.get_pretrained_path(model)
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
- model, optimizer, criterion, dataloader, lr_scheduler)
+ model, optimizer, criterion, dataloader, lr_scheduler
+ )
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
- model = self.accelerator.configure(model)
+ model = self.accelerator.configure_model(model)
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
+ if pretrained_path:
+ self.load_model(model, pretrained_path)
+ # clear pretrained path attr
+ orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+ pretrained_utils.set_pretrained_path(orig_model, None)
+
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
@@ -154,13 +166,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
- def execute_pipeline(self,
- data_iter: Iterator,
- model: nn.Module,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[Optimizer] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> Dict[str, Any]:
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: nn.Module,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[Optimizer] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> Dict[str, Any]:
"""
Execute forward & backward when utilizing pipeline parallel.
Return loss or Huggingface style model outputs if needed.
@@ -185,8 +199,9 @@ def execute_pipeline(self,
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
"""
- assert isinstance(self.plugin,
- PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+ assert isinstance(
+ self.plugin, PipelinePluginBase
+ ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
@@ -200,8 +215,10 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
Returns:
contextmanager: Context to disable gradient synchronization.
"""
- assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
- assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+ assert (
+ self.plugin is not None
+ ), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
+ assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
@@ -217,14 +234,16 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
- def save_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False) -> None:
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
"""Save model to checkpoint.
Args:
@@ -239,13 +258,15 @@ def save_model(self,
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safetensors. Default: False. If set to True, the checkpoint will be saved in the safetensors format.
"""
- self.checkpoint_io.save_model(model,
- checkpoint=checkpoint,
- shard=shard,
- gather_dtensor=gather_dtensor,
- prefix=prefix,
- size_per_shard=size_per_shard,
- use_safetensors=use_safetensors)
+ self.checkpoint_io.save_model(
+ model,
+ checkpoint=checkpoint,
+ shard=shard,
+ gather_dtensor=gather_dtensor,
+ prefix=prefix,
+ size_per_shard=size_per_shard,
+ use_safetensors=use_safetensors,
+ )
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
@@ -260,13 +281,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
- def save_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024) -> None:
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ) -> None:
"""
Save optimizer to checkpoint.
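
Putting the reflowed `Booster` surface together: boost the training objects once, route backward through the booster so the plugin and mixed precision stay in control, and checkpoint via `save_model`. A condensed sketch, assuming a `torchrun`-style launch and a CUDA device (the tiny model, optimizer, and dataset here are placeholders):

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # reads rank/world size from torchrun env vars

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))

plugin = TorchDDPPlugin()
loader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, loader, _ = booster.boost(model, optimizer, criterion, loader)

for x, y in loader:
    loss = criterion(model(x.cuda()), y.cuda())
    booster.backward(loss, optimizer)  # plugin- and precision-aware backward
    optimizer.step()
    optimizer.zero_grad()

# one shard file per at most size_per_shard MB, plus an index file
booster.save_model(model, "ckpt", shard=True, size_per_shard=1024)
```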
diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py
index 0df9d84159f9..68c6221ec809 100644
--- a/colossalai/booster/mixed_precision/__init__.py
+++ b/colossalai/booster/mixed_precision/__init__.py
@@ -6,16 +6,22 @@
from .mixed_precision_base import MixedPrecision
__all__ = [
- 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
- 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision'
+ "MixedPrecision",
+ "mixed_precision_factory",
+ "FP16_Apex_MixedPrecision",
+ "FP16_Torch_MixedPrecision",
+ "FP32_MixedPrecision",
+ "BF16_MixedPrecision",
+ "FP8_MixedPrecision",
+ "FP16NaiveMixedPrecision",
]
_mixed_precision_mapping = {
- 'fp16': FP16TorchMixedPrecision,
- 'fp16_apex': FP16ApexMixedPrecision,
- 'fp16_naive': FP16NaiveMixedPrecision,
- 'bf16': BF16MixedPrecision,
- 'fp8': FP8MixedPrecision
+ "fp16": FP16TorchMixedPrecision,
+ "fp16_apex": FP16ApexMixedPrecision,
+ "fp16_naive": FP16NaiveMixedPrecision,
+ "bf16": BF16MixedPrecision,
+ "fp8": FP8MixedPrecision,
}
@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
- f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
+ f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
)
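
The string keys of `_mixed_precision_mapping` are exactly what `Booster(mixed_precision=...)` accepts; `mixed_precision_factory` just instantiates the matching class. For example:

```python
from colossalai.booster.mixed_precision import mixed_precision_factory

mp = mixed_precision_factory("fp16")  # returns a FP16TorchMixedPrecision instance
# an unknown key raises ValueError listing fp16, fp16_apex, fp16_naive, bf16, fp8
```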
diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py
index e184271e932a..2fa7b54cdd30 100644
--- a/colossalai/booster/mixed_precision/fp16_apex.py
+++ b/colossalai/booster/mixed_precision/fp16_apex.py
@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision):
max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
"""
- def __init__(self,
- opt_level: Optional[str] = "O1",
- cast_model_type: torch.dtype = None,
- patch_torch_functions: bool = None,
- keep_batchnorm_fp32: Union[bool, str] = None,
- master_weights: bool = None,
- loss_scale: Union[float, str] = None,
- cast_model_outputs: Any = None,
- num_losses: Optional[int] = 1,
- verbosity: int = 1,
- min_loss_scale: float = None,
- max_loss_scale: float = 2.**24) -> None:
+ def __init__(
+ self,
+ opt_level: Optional[str] = "O1",
+ cast_model_type: torch.dtype = None,
+ patch_torch_functions: bool = None,
+ keep_batchnorm_fp32: Union[bool, str] = None,
+ master_weights: bool = None,
+ loss_scale: Union[float, str] = None,
+ cast_model_outputs: Any = None,
+ num_losses: Optional[int] = 1,
+ verbosity: int = 1,
+ min_loss_scale: float = None,
+ max_loss_scale: float = 2.0**24,
+ ) -> None:
pass
diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py
index 5d0d815257f3..e5624a9d7477 100644
--- a/colossalai/booster/mixed_precision/fp16_naive.py
+++ b/colossalai/booster/mixed_precision/fp16_naive.py
@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision):
verbose(bool): if set to `True`, will print debug info.
"""
- def __init__(self,
- log_num_zeros_in_grad: bool,
- initial_scale: int,
- growth_factor: int,
- backoff_factor: float,
- hysteresis: int,
- max_scale: int,
- verbose: bool = None) -> None:
+ def __init__(
+ self,
+ log_num_zeros_in_grad: bool,
+ initial_scale: int,
+ growth_factor: int,
+ backoff_factor: float,
+ hysteresis: int,
+ max_scale: int,
+ verbose: bool = None,
+ ) -> None:
pass
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 26fd92bd50b8..7dce6e6da33e 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -9,7 +9,7 @@
from .mixed_precision_base import MixedPrecision
-__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- optim: Optimizer,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ optim: Optimizer,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__(optim)
- self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
+ self.scaler = torch.cuda.amp.GradScaler(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
@@ -60,12 +64,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> None:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> None:
self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__()
- self.torch_amp_kwargs = dict(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
-
- def configure(self,
- model: nn.Module,
- optimizer: Optional[Optimizer] = None,
- criterion: Optional[Callable] = None,
- ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ self.torch_amp_kwargs = dict(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
+
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
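
`TorchAMPOptimizer` is a thin wrapper over `torch.cuda.amp.GradScaler`: `backward` scales the loss before the backward pass, and the clip methods unscale gradients first. The underlying plain-PyTorch pattern it packages looks roughly like this (requires a GPU):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(
    init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000
)

x = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()  # what TorchAMPOptimizer.backward does internally
scaler.unscale_(optim)         # done before clipping, as in clip_grad_by_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optim)             # skips the step if inf/nan gradients were found
scaler.update()
```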
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index f48bf38bd724..62f3708fc629 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -4,11 +4,12 @@
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
+__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
import torch
from packaging import version
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
from .torch_fsdp_plugin import TorchFSDPPlugin
- __all__.append('TorchFSDPPlugin')
+
+ __all__.append("TorchFSDPPlugin")
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index d5da5938bfd9..d2dd00453e32 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -10,25 +10,19 @@
class DPPluginBase(Plugin):
- """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.
- """
+ """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
def __init__(self) -> None:
super().__init__()
- assert dist.is_initialized(
- ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+ assert (
+ dist.is_initialized()
+ ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
- def prepare_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
@@ -60,11 +54,13 @@ def seed_worker(worker_id):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
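
The core of `prepare_dataloader` is pairing a `DistributedSampler` with a seeded `worker_init_fn` so shuffling stays reproducible across ranks and workers. The plain-PyTorch essence (replica count and rank are passed explicitly here instead of being read from the process group):

```python
import random
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(32).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=1024)

def seed_worker(worker_id):
    # derive a per-worker seed from torch's initial seed, as the plugin does
    worker_seed = torch.initial_seed() % 2**32
    torch.manual_seed(worker_seed)
    random.seed(worker_seed)

loader = DataLoader(
    dataset, batch_size=8, sampler=sampler, worker_init_fn=seed_worker, drop_last=False
)
```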
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index de03ba27bfda..ca722a0768dc 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -27,14 +27,13 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['GeminiPlugin']
+__all__ = ["GeminiPlugin"]
-SUPPORTED_PRECISION = ['fp16', 'bf16']
-PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+SUPPORTED_PRECISION = ["fp16", "bf16"]
+PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
@@ -45,6 +44,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
+ assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -54,37 +54,43 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
+ assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
- def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
Load an unsharded optimizer from a checkpoint file.
Each process loads only the optimizer states of the parameters it controls.
"""
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
- def save_sharded_model(self,
- model: GeminiDDP,
- checkpoint_path: str,
- gather_dtensor: bool = False,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: GeminiDDP,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
+ assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
@@ -97,40 +103,43 @@ def save_sharded_model(self,
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint_path,
- index_file=index_file,
- base_filename=weights_name,
- is_master=is_master,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=is_master,
+ use_safetensors=use_safetensors,
+ )
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- save_config_file(model.module, checkpoint_path)
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
-
- def load_sharded_model(self,
- model: GeminiDDP,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False):
+ save_config_file(model.unwrap(), checkpoint_path)
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_model(
+ self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
+ ):
"""
Load a sharded model from multiple checkpoint files.
"""
+ assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
-
- assert isinstance(optimizer, GeminiOptimizer)
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -153,27 +162,31 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=is_master,
- use_safetensors=False)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=is_master,
+ use_safetensors=False,
+ )
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
- def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
"""
Load a sharded optimizer from a checkpoint folder, given its index file.
Each process loads only the optimizer states of the parameters it controls.
"""
-
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
@@ -185,8 +198,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa
# Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups)
@@ -214,16 +229,17 @@ class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import GeminiPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = GeminiPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import GeminiPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = GeminiPlugin()
- >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
chunk_config_dict (dict, optional): chunk configuration dictionary.
@@ -274,11 +290,11 @@ def __init__(
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
- shard_param_frac: float = 1.0, # only for static placement
- offload_optim_frac: float = 0.0, # only for static placement
- offload_param_frac: float = 0.0, # only for static placement
- warmup_non_model_data_ratio: float = 0.8, # only for auto placement
- steady_cuda_cap_ratio: float = 0.9, # only for auto placement
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ warmup_non_model_data_ratio: float = 0.8, # only for auto placement
+ steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
@@ -300,7 +316,7 @@ def __init__(
verbose: bool = False,
) -> None:
super().__init__()
- assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
+ assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
@@ -319,16 +335,20 @@ def __init__(
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
- self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+ self.zero_optim_config = dict(
+ gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+ )
+ self.optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ )
self.verbose = verbose
def support_no_sync(self) -> bool:
@@ -344,7 +364,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -354,7 +374,6 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -368,13 +387,10 @@ def configure(
# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
- optimizer = GeminiOptimizer(optimizer,
- model.unwrap(),
- **self.zero_optim_config,
- **self.optim_kwargs,
- verbose=self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer = GeminiOptimizer(
+ optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+ )
return model, optimizer, criterion, dataloader, lr_scheduler
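
With the new assertions, Gemini's checkpoint IO refuses objects that were not boosted first. A sketch of the sharded save path under those rules, assuming a launched distributed environment with CUDA (Gemini is typically paired with `HybridAdam`; the tiny model here is a placeholder):

```python
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # torchrun-style launch assumed

model = nn.Linear(16, 4)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin(precision="fp16"))
model, optimizer, *_ = booster.boost(model, optimizer)

# writes weight/state shards plus an index file; only the master rank writes the index
booster.save_model(model, "gemini_ckpt", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "gemini_optim", shard=True)
```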
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fc04f3ecd8e7..479ccc3eb36e 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,6 +1,7 @@
import random
from contextlib import nullcontext
from functools import partial
+from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
@@ -22,6 +23,7 @@
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -36,28 +38,37 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):
-
- def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
- ddp_config: dict) -> None:
-
+ def __init__(
+ self,
+ module: Module,
+ precision: str,
+ shard_config: ShardConfig,
+ dp_group: ProcessGroup,
+ use_ddp: bool,
+ ddp_config: dict,
+ custom_policy: Policy,
+ ) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
- module, self.shared_params = shardformer.optimize(module)
+ if custom_policy is not None:
+ assert isinstance(custom_policy, Policy)
+ module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
# setting process groups for shared parameters
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
- self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+ self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
+ )
# setting mixed_precision
self.mixed_precision = None
- if precision == 'fp16':
+ if precision == "fp16":
self.mixed_precision = torch.float16
- elif precision == 'bf16':
+ elif precision == "bf16":
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
@@ -120,22 +131,21 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
- param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+ param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
+ packed_group = {k: v for k, v in group.items() if k != "params"}
+ packed_group["params"] = []
- packed_group = {k: v for k, v in group.items() if k != 'params'}
- packed_group['params'] = []
-
- for param_id, param in enumerate(group['params'], start_index):
+ for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
- packed_group['params'].append(param_id)
- param_info['param2id'][id(param)] = param_id
- param_info['id2param'][param_id] = id(param)
- param_info['param2shape'][id(param)] = original_shape
+ packed_group["params"].append(param_id)
+ param_info["param2id"][id(param)] = param_id
+ param_info["id2param"][param_id] = id(param)
+ param_info["param2shape"][id(param)] = original_shape
- param_info['param_groups'].append(packed_group)
- start_index += len(group['params'])
+ param_info["param_groups"].append(packed_group)
+ start_index += len(group["params"])
return param_info
@@ -144,75 +154,110 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
- params = [p for p in group['params'] if p in model_params]
- new_param_groups.append({**group, 'params': params})
- optim.__setstate__({'param_groups': new_param_groups})
+ params = [p for p in group["params"] if p in model_params]
+ new_param_groups.append({**group, "params": params})
+ optim.__setstate__({"param_groups": new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
-
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
+ def update_master_params(self, model: Module):
+ pass
-class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
+ def get_working_to_master_map(self):
+ return None
- def __init__(self,
- optim: Optimizer,
- model: Module,
- use_pipeline: bool,
- param_info: OrderedDict,
- precision: str = 'fp16',
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0):
+ def get_master_to_working_map(self):
+ return None
+
+
+class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
+ def __init__(
+ self,
+ optim: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ ):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
- super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
- hysteresis, max_scale, max_norm)
+ super().__init__(
+ optim,
+ precision,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ max_norm,
+ )
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
-
def __init__(
- self,
- optimizer: Optimizer,
- model: Module,
- use_pipeline: bool,
- param_info: OrderedDict,
- initial_scale: int = 2**16, # grad scaler config
- min_scale: int = 1,
- growth_factor: float = 2.,
- backoff_factor: float = .5,
- growth_interval: int = 2000,
- hysteresis: int = 2,
- max_scale: int = 2**24,
- clip_grad_norm: float = 0.0, # grad clipping
- verbose: bool = False,
- reduce_bucket_size: int = 1024 * 1024, # communication
- communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
- partition_grad: bool = False, # stage 2 flag
- cpu_offload: bool = False, # cpu offload
- dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
- tp_process_group: Optional[ProcessGroup] = None, # if using tp
- forced_dtype: Optional[torch.dtype] = None):
+ self,
+ optimizer: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ max_scale: int = 2**24,
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
+ forced_dtype: Optional[torch.dtype] = None,
+ ):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
- super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
- hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
- overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
- forced_dtype)
+ super().__init__(
+ optimizer,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ clip_grad_norm,
+ verbose,
+ reduce_bucket_size,
+ communication_dtype,
+ overlap_communication,
+ partition_grad,
+ cpu_offload,
+ dp_process_group,
+ tp_process_group,
+ forced_dtype,
+ )
class HybridParallelPlugin(PipelinePluginBase):
@@ -221,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase):
    Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
    The sizes of tp and pp should be passed in by the user; the size of dp is then computed automatically as dp_size = world_size / (tp_size * pp_size).
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import HybridParallelPlugin
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import HybridParallelPlugin
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+ model, train_dataset, optimizer, criterion = ...
+ plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
- >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+ ```
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
@@ -243,9 +289,11 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
- enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
- enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
- enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -268,47 +316,50 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
- def __init__(self,
- tp_size: int,
- pp_size: int,
- precision: str = 'fp16',
- zero_stage: int = 0,
- enable_all_optimization: bool = False,
- enable_fused_normalization: bool = False,
- enable_flash_attention: bool = False,
- enable_jit_fused: bool = False,
- enable_sequence_parallelism: bool = False,
- enable_sequence_overlap: bool = False,
- num_microbatches: Optional[int] = None,
- microbatch_size: Optional[int] = None,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0,
- broadcast_buffers: bool = True,
- ddp_bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False,
- zero_bucket_size_in_m: int = 12,
- cpu_offload: bool = False,
- communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True) -> None:
-
+ def __init__(
+ self,
+ tp_size: int,
+ pp_size: int,
+ precision: str = "fp16",
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ custom_policy: Policy = None,
+ ) -> None:
super().__init__()
- assert dist.get_world_size() % (
- tp_size * pp_size
- ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
- assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
+            assert tp_size > 1, "Tensor parallelism must be enabled when using sequence parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
@@ -324,26 +375,31 @@ def __init__(self,
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
+ self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
- assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
- assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
+ assert (
+ num_microbatches is not None or microbatch_size is not None
+ ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
+ assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
- self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
- num_microbatches=num_microbatches,
- microbatch_size=microbatch_size)
+ self.schedule = OneForwardOneBackwardSchedule(
+ self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ )
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
- self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
- pipeline_stage_manager=self.stage_manager,
- enable_tensor_parallelism=self.tp_size > 1,
- enable_all_optimization=self.enable_all_optimization,
- enable_fused_normalization=self.enable_fused_normalization,
- enable_flash_attention=self.enable_flash_attention,
- enable_jit_fused=self.enable_jit_fused,
- enable_sequence_parallelism=enable_sequence_parallelism,
- enable_sequence_overlap=enable_sequence_overlap)
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap,
+ )
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -354,18 +410,22 @@ def __init__(self,
max_scale=max_scale,
)
- self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=ddp_bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph)
+ self.ddp_config = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
- self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload,
- partition_grad=(self.zero_stage == 2))
+ self.zero_config = dict(
+ reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2),
+ )
self.max_norm = max_norm
@@ -374,10 +434,10 @@ def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def supported_precisions(self) -> List[str]:
- return ['fp16', 'bf16', 'fp32']
+ return ["fp16", "bf16", "fp32"]
def control_device(self) -> bool:
return True
@@ -402,57 +462,62 @@ def configure(
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
- model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
- self.ddp_config)
+ model = HybridParallelModule(
+ model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+ )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
- if self.precision in ['fp16', 'bf16']:
- optimizer = HybridParallelAMPOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info,
- precision=self.precision,
- max_norm=self.max_norm,
- **self.amp_config)
- self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
- optimizer.master_to_working_map)
+ if self.precision in ["fp16", "bf16"]:
+ optimizer = HybridParallelAMPOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ precision=self.precision,
+ max_norm=self.max_norm,
+ **self.amp_config,
+ )
else:
- optimizer = HybridParallelNaiveOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info)
+ optimizer = HybridParallelNaiveOptimizer(
+ optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+ )
else:
                assert self.dp_size > 1, "Please only use ZeRO when data parallel size is greater than 1."
- assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
- optimizer = HybridParallelZeroOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info,
- dp_process_group=self.dp_group,
- tp_process_group=self.tp_group,
- verbose=True,
- clip_grad_norm=self.max_norm,
- **self.zero_config,
- **self.amp_config)
- self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
- optimizer._param_store.master_to_working_param)
-
+ assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ verbose=True,
+ clip_grad_norm=self.max_norm,
+ **self.zero_config,
+ **self.amp_config,
+ )
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
- def execute_pipeline(self,
- data_iter: Iterator,
- model: HybridParallelModule,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
- HybridParallelZeroOptimizer]] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> dict:
- assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: HybridParallelModule,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[
+ Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
+ ] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
+ assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
- outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
- return_outputs)
+ outputs = self.schedule.forward_backward_step(
+ model, data_iter, criterion, optimizer, return_loss, return_outputs
+ )
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
@@ -460,15 +525,9 @@ def execute_pipeline(self,
model.sync_grads()
return outputs
- def prepare_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
@@ -491,10 +550,9 @@ def prepare_dataloader(self,
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset,
- num_replicas=self.pg_mesh.size(DP_AXIS),
- rank=self.pg_mesh.coordinate(DP_AXIS),
- shuffle=shuffle)
+ sampler = DistributedSampler(
+ dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ )
# Deterministic dataloader
def seed_worker(worker_id):
@@ -503,18 +561,19 @@ def seed_worker(worker_id):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
def get_checkpoint_io(self) -> CheckpointIO:
- self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
- return self.checkpoint_io
+ return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
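
Beyond the black-style reformatting, the functional change in this file is the new `custom_policy` argument, threaded from the plugin into `HybridParallelModule` so users can override Shardformer's auto-selected sharding policy. A minimal sketch of the intended wiring, assuming `colossalai.launch` has already initialized the process group and that `my_policy` is an instance of a concrete Shardformer `Policy` subclass:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def build_booster(my_policy=None):
    # custom_policy defaults to None, in which case Shardformer falls back
    # to its auto-registered policy for the model architecture.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=2, custom_policy=my_policy)
    return Booster(plugin=plugin)
```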
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 9adb4beec9b9..0e515a55a8e3 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,14 +1,12 @@
import logging
import os
-import warnings
from functools import partial
from pathlib import Path
from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
-from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -24,7 +22,6 @@
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
- unwrap_optimizer,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -33,7 +30,7 @@
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
-__all__ = ['LowLevelZeroPlugin']
+__all__ = ["LowLevelZeroPlugin"]
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
@@ -42,17 +39,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
return x
-SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
- if precision == 'fp16':
+ if precision == "fp16":
self.dtype = torch.float16
- elif precision == 'bf16':
+ elif precision == "bf16":
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
@@ -68,13 +64,8 @@ def forward(self, *args, **kwargs):
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
- def unwrap(self):
- # TODO(ver217): this is a workaround for loading model
- return self
-
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.
@@ -83,7 +74,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
-
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
        # the `state_dict` in LowLevelZeroOptimizer involves collective communication;
        # if only the master rank collected the state_dict and saved it,
        # the communication on each rank would not match
@@ -91,12 +82,14 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
- def save_sharded_optimizer(self,
- optimizer: OptimizerWrapper,
- checkpoint: str,
- gather_dtensor: bool = False,
- prefix: str = None,
- size_per_shard: int = 1024):
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = False,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -111,6 +104,7 @@ def save_sharded_optimizer(self,
            prefix (str): Prefix of file to save
            size_per_shard (int): Max file size of each file that stores state tensors
"""
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -148,9 +142,11 @@ def save_sharded_optimizer(self,
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+            f"The optimizer is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.
@@ -160,9 +156,8 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = unwrap_optimizer(optimizer)
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before loading!"
+ optimizer = optimizer.unwrap()
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -170,8 +165,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory."
+ )
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -181,9 +178,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
- if isinstance(v, torch.Tensor) and k != 'step':
- padding_size = (self.coordinator.world_size -
- v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ if isinstance(v, torch.Tensor) and k != "step":
+ padding_size = (
+ self.coordinator.world_size - v.numel() % self.coordinator.world_size
+ ) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
@@ -191,38 +189,23 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
load_states_into_optimizer(optimizer, state_dict, id_map)
-
sharded_optimizer_loading_epilogue(optimizer)
- def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
- use_safetensors: bool):
- assert isinstance(model, LowLevelZeroModel)
- super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
-
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
- assert isinstance(model, LowLevelZeroModel)
- super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
- use_safetensors)
-
- def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
- assert isinstance(model, LowLevelZeroModel)
- super().load_unsharded_model(model.module, checkpoint, strict)
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()
- def load_sharded_model(self,
- model: LowLevelZeroModel,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
- assert isinstance(model, LowLevelZeroModel)
- super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ def load_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
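
The `load_sharded_optimizer` hunk above re-shards full optimizer states for ZeRO: each state tensor is flattened, padded so its length divides the world size, split, and only the local rank's slice is kept. A standalone sketch of that scheme (illustrative, not the upstream helper):

```python
import torch
import torch.nn.functional as F

def shard_for_rank(v: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Pad the flattened tensor so it splits evenly, then keep this rank's slice.
    v = v.flatten()
    padding = (world_size - v.numel() % world_size) % world_size
    if padding > 0:
        v = F.pad(v, [0, padding])
    return v.split(v.numel() // world_size)[rank].detach().clone()
```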
@@ -230,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import LowLevelZeroPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = LowLevelZeroPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import LowLevelZeroPlugin
- >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ model, train_dataset, optimizer, criterion = ...
+ plugin = LowLevelZeroPlugin()
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
        stage (int, optional): ZeRO stage. Defaults to 1.
@@ -264,7 +248,7 @@ class LowLevelZeroPlugin(DPPluginBase):
def __init__(
self,
stage: int = 1,
- precision: str = 'fp16',
+ precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@@ -281,9 +265,9 @@ def __init__(
verbose: bool = False,
) -> None:
super().__init__()
- assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
- assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
- assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
+        assert stage in (1, 2), "LowLevelZeroPlugin only supports stage 1/2 training"
+        assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
+        assert norm_type == 2.0, "LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
@@ -319,7 +303,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -329,15 +313,13 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
- optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
- **self.zero_optim_kwargs,
- verbose=self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
+ optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+ )
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
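
The `MethodType` injection at the end of `configure` (also added to the hybrid plugin above) binds the optimizer's `update_master_params` onto the wrapped model, so checkpoint loading can refresh the fp32 master copies right after the working weights are loaded. A self-contained illustration of the pattern:

```python
from types import MethodType

class Opt:
    def update_master_params(self, model):
        print(f"refreshing master params from {type(model).__name__}")

class Net:
    pass

opt, net = Opt(), Net()
# After binding, net.update_master_params() dispatches to the optimizer's
# logic with `net` passed in as the `model` argument.
net.update_master_params = MethodType(opt.update_master_params, net)
net.update_master_params()  # prints: refreshing master params from Net
```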
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index fb21e57f41f7..4e570cbe8abc 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@@ -9,11 +9,10 @@
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
-__all__ = ['Plugin']
+__all__ = ["Plugin"]
class Plugin(ABC):
-
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@@ -51,33 +50,31 @@ def control_checkpoint_io(self) -> bool:
"""
Whether the plugin controls the checkpoint io
"""
- pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
- pass
@abstractmethod
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
- pass
@abstractmethod
- def prepare_dataloader(self,
- dataset: Dataset,
- batch_size: int,
- shuffle: bool = False,
- seed: int = 1024,
- drop_last: bool = False,
- pin_memory: bool = False,
- num_workers: int = 0,
- **kwargs):
+ def prepare_dataloader(
+ self,
+ dataset: Dataset,
+ batch_size: int,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ **kwargs,
+ ):
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader`
"""
- pass
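
Dropping the trailing `pass` statements is behavior-preserving: a docstring is itself a valid function body, so the abstract methods compile exactly as before.

```python
from abc import ABC, abstractmethod

class Example(ABC):
    @abstractmethod
    def supported_devices(self):
        """A docstring alone is a complete method body; no `pass` is needed."""
```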
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
index f52844db082f..3d91eb95b409 100644
--- a/colossalai/booster/plugin/pp_plugin_base.py
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -9,13 +9,14 @@
class PipelinePluginBase(Plugin):
-
@abstractmethod
- def execute_pipeline(self,
- data_iter: Iterator,
- model: ModelWrapper,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[OptimizerWrapper] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> dict:
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: ModelWrapper,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
pass
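
`execute_pipeline` is what a training loop calls in place of a plain `loss.backward()` once pipeline parallelism is active. A hedged sketch of a single step built on this interface, assuming the objects come from a prior `booster.boost(...)` as in the plugin docstrings:

```python
def train_step(booster, model, optimizer, criterion, data_iter):
    # The pipeline schedule drives forward and backward over all microbatches.
    outputs = booster.execute_pipeline(
        data_iter, model, criterion, optimizer, return_loss=True
    )
    optimizer.step()
    optimizer.zero_grad()
    return outputs.get("loss")  # present only on the final pipeline stage
```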
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index f3f779c88e42..738634473dbc 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -12,33 +12,41 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['TorchDDPPlugin']
+__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
- def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
- Load model from checkpoint with automatic unwrapping.
+ Load model from checkpoint.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- return super().load_unsharded_model(model, checkpoint, strict=strict)
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
- def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
- super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+ super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+ """
+ Load optimizer from checkpoint.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+ super().load_unsharded_optimizer(optimizer, checkpoint)
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
@@ -49,34 +57,67 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save model to checkpoint but only on master process.
"""
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
- super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+ super().save_sharded_model(
+ model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+ )
- def save_sharded_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ def load_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_index_file: str,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
- Save optimizer to checkpoint but only on master process.
+ Load model from sharded checkpoint.
+ """
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save optimizer to sharded checkpoint but only on master process.
"""
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
- super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+ super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+ def load_sharded_optimizer(
+ self,
+ optimizer: Optimizer,
+ index_file_path: str,
+ prefix: Optional[str] = None,
+ ):
+ """
+ Load optimizer from sharded checkpoint.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+ super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
-class TorchDDPModel(ModelWrapper):
+class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
@@ -89,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import TorchDDPPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = TorchDDPPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import TorchDDPPlugin
- >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ model, train_dataset, optimizer, criterion = ...
+ plugin = TorchDDPPlugin()
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
@@ -109,20 +151,24 @@ class TorchDDPPlugin(DPPluginBase):
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
- def __init__(self,
- broadcast_buffers: bool = True,
- bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False) -> None:
+ def __init__(
+ self,
+ broadcast_buffers: bool = True,
+ bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ ) -> None:
super().__init__()
- self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph)
+ self.ddp_kwargs = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
def support_no_sync(self) -> bool:
return True
@@ -131,13 +177,13 @@ def control_precision(self) -> bool:
return False
def supported_precisions(self) -> List[str]:
- return ['fp16', 'fp16_apex', 'bf16', 'fp8']
+ return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -156,8 +202,7 @@ def configure(
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -169,5 +214,5 @@ def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
- assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+ assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
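
The `no_sync` hook at the end of this file is the building block for gradient accumulation with DDP: all-reduce is skipped inside the context and fires only on the final backward. A sketch of the intended usage:

```python
def accumulate(plugin, model, optimizer, criterion, batches):
    # Skip gradient synchronization for all but the last micro-batch.
    *head, last = batches
    with plugin.no_sync(model, optimizer):
        for inputs, targets in head:
            criterion(model(inputs), targets).backward()  # local grads only
    inputs, targets = last
    criterion(model(inputs), targets).backward()  # triggers all-reduce
    optimizer.step()
    optimizer.zero_grad()
```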
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index fb7b5baadd0c..2ea7593a5cc5 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,13 +1,13 @@
import warnings
from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@@ -31,64 +31,77 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['TorchFSDPPlugin']
+__all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
- def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+ model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
- def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
- def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+ model = model.unwrap()
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
- assert isinstance(optimizer, FSDPOptimizerWrapper)
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
- def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
- size_per_shard: int, use_safetensors: bool):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
- def load_sharded_model(self,
- model: nn.Module,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
        Load model from checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -109,7 +122,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
class TorchFSDPModel(ModelWrapper):
-
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
@@ -119,7 +131,6 @@ def unwrap(self):
class FSDPOptimizerWrapper(OptimizerWrapper):
-
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
@@ -132,22 +143,23 @@ class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import TorchFSDPPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = TorchFSDPPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import TorchFSDPPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = TorchFSDPPlugin()
- >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
- if version.parse(torch.__version__) >= version.parse('1.12.0'):
+ if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__(
self,
@@ -162,15 +174,18 @@ def __init__(
sync_module_states: bool = False,
):
super().__init__()
- self.fsdp_kwargs = dict(process_group=process_group,
- sharding_strategy=sharding_strategy,
- cpu_offload=cpu_offload,
- auto_wrap_policy=auto_wrap_policy,
- backward_prefetch=backward_prefetch,
- mixed_precision=mixed_precision,
- ignored_modules=ignored_modules,
- param_init_fn=param_init_fn,
- sync_module_states=sync_module_states)
+ self.fsdp_kwargs = dict(
+ process_group=process_group,
+ sharding_strategy=sharding_strategy,
+ cpu_offload=cpu_offload,
+ auto_wrap_policy=auto_wrap_policy,
+ backward_prefetch=backward_prefetch,
+ mixed_precision=mixed_precision,
+ ignored_modules=ignored_modules,
+ param_init_fn=param_init_fn,
+ sync_module_states=sync_module_states,
+ )
+
else:
        raise RuntimeError("FSDP is not supported for torch versions below 1.12.0.")
@@ -184,13 +199,13 @@ def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
- return ['fp16', 'bf16']
+ return ["fp16", "bf16"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -200,14 +215,13 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
- 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                    "TorchFSDPPlugin does not support optimizers that use multiple param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
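
The `optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)` call above exists because FSDP flattens and re-registers parameters, so an optimizer built against the raw module would hold stale references. In practice the ordering looks like this sketch, which assumes a single param group:

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

def boost_with_fsdp(model):
    # Build the optimizer first; boost() re-binds it to the flattened
    # FSDP parameters via the __init__ call shown in the diff above.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    booster = Booster(plugin=TorchFSDPPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)
    return model, optimizer, booster
```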
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index e1aa6543ef39..19b61730bded 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -3,4 +3,4 @@
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
+__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index baff24e1cb25..780117598e18 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -11,7 +11,7 @@
from .utils import has_index_file
-__all__ = ['CheckpointIO']
+__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
- def load_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- strict: bool = True) -> Union[nn.Module, ModelWrapper]:
+ def load_model(
+ self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+ ) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -88,9 +87,6 @@ def load_model(self,
# return the origin model instead of the unwrapped model
origin_model = model
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
-
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
@@ -98,14 +94,16 @@ def load_model(self,
return origin_model
- def save_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: str = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False):
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save model to checkpoint.
@@ -133,9 +131,6 @@ def save_model(self,
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
-
if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
@@ -157,7 +152,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
- raise ValueError(f'Cannot find index file in {checkpoint}')
+ raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
@@ -165,13 +160,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
- def save_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- shard: bool = False,
- gather_dtensor=True,
- prefix: str = None,
- size_per_shard: int = 1024):
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor=True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -207,7 +204,6 @@ def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: boo
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
- pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
@@ -220,11 +216,17 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
- pass
@abstractmethod
- def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
- size_per_shard: int, use_safetensors: bool):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
"""
Save model to sharded checkpoint.
@@ -236,7 +238,6 @@ def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor:
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
@@ -249,7 +250,6 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -265,7 +265,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
"""
- pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
@@ -276,11 +275,11 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
- pass
@abstractmethod
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save optimizer to sharded checkpoint.
@@ -291,7 +290,6 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
- pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
@@ -303,7 +301,6 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gathe
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
- pass
# ============================================
# methods for loading and saving lr scheduler
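
With the automatic `ModelWrapper.unwrap()` removed from `load_model` and `save_model`, unwrapping is now each plugin CheckpointIO's responsibility, which is why the plugin IOs above assert that objects have been boosted first. The resulting usage contract, sketched with illustrative paths:

```python
def checkpoint_roundtrip(booster, model, optimizer):
    # Always pass the boosted handles; the plugin's CheckpointIO decides
    # how to unwrap or shard them for its parallelism scheme.
    booster.save_model(model, "ckpt/model", shard=True, size_per_shard=1024)
    booster.save_optimizer(optimizer, "ckpt/optim", shard=True)
    booster.load_model(model, "ckpt/model")
    booster.load_optimizer(optimizer, "ckpt/optim")
```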
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index faaf1d22722a..a652d9b4538e 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -3,20 +3,16 @@
import os
from functools import reduce
from pathlib import Path
-from typing import Iterator, Optional, OrderedDict, Tuple
+from typing import Optional
-import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
-from colossalai.interface import OptimizerWrapper
-
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
- get_shard_filename,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@@ -30,10 +26,9 @@
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
- unwrap_optimizer,
)
-__all__ = ['GeneralCheckpointIO']
+__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
@@ -60,18 +55,16 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
Load sharded optimizer with the given path to index file.
"""
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = unwrap_optimizer(optimizer)
-
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory."
+ )
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -98,10 +91,6 @@ def save_sharded_optimizer(
- Multiple files (pytorch_optim-000XX.bin) that store optimizer state tensors in a sharded way
"""
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = unwrap_optimizer(optimizer)
-
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -123,19 +112,23 @@ def save_sharded_optimizer(
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
- total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=True,
- use_safetensors=False)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=True,
+ use_safetensors=False,
+ )
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
@@ -150,13 +143,15 @@ def save_unsharded_optimizer(
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = False,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
implement this method as it is supported by Hugging Face models:
save a sharded model, i.e. save the model to multiple files
@@ -175,26 +170,32 @@ def save_sharded_model(self,
# Save shards of model states.
# In general cases, is_master is set to True to get the right behavior.
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint_path,
- index_file=index_file,
- base_filename=weights_name,
- is_master=True,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=True,
+ use_safetensors=use_safetensors,
+ )
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
- logging.info(f"The model is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
-
- def load_sharded_model(self,
- model: nn.Module,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
+ logging.info(
+ f"The model is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
load a sharded model, i.e. load the model from multiple files
"""
@@ -219,7 +220,11 @@ def load_sharded_model(self,
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
- error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
- '"{}"'.format(k) for k in missing_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- self.__class__.__name__, "\n\t".join(error_msgs)))
+ error_msgs = "Missing key(s) in state_dict: {}. ".format(
+ ", ".join('"{}"'.format(k) for k in missing_keys)
+ )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ self.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
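
Taken together, the `GeneralCheckpointIO` changes above are formatting fixes plus removal of the optimizer-unwrapping logic. A rough usage sketch of the sharded model path follows (the toy model, the directory name, and the index-file name are assumptions for illustration; the actual index file is whatever `save_sharded_model` writes):

```python
# Hedged usage sketch for the refactored GeneralCheckpointIO; paths and the
# toy model are made up for illustration.
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(8, 8)
ckpt_io = GeneralCheckpointIO()

# Writes multiple shard files plus a .index.json file under ckpt_dir.
ckpt_io.save_sharded_model(model, "ckpt_dir", max_shard_size=1024)

# strict=False because params on the same device may live in different shards.
ckpt_io.load_sharded_model(model, "ckpt_dir/pytorch_model.bin.index.json", strict=False)
```
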
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 6eee3ace0308..779ff42d75a1 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,19 +1,18 @@
import copy
-import gc
import logging
import os
from pathlib import Path
from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
-from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from colossalai.interface import OptimizerWrapper
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -24,19 +23,21 @@
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
+ load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
+ save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
try:
- from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
- _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO):
@@ -51,12 +52,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
- def __init__(self,
- dp_group: ProcessGroup,
- pp_group: ProcessGroup,
- tp_group: ProcessGroup,
- zero_stage: int,
- verbose: bool = True) -> None:
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True,
+ ) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
@@ -67,16 +70,14 @@ def __init__(self,
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
- self.use_zero = (zero_stage > 0)
+ self.use_zero = zero_stage > 0
self.verbose = verbose
- self.working_to_master_map = None
- self.master_to_working_map = None
+ self.coordinator = DistCoordinator()
@staticmethod
- def _model_sharder(model: nn.Module,
- prefix: str = '',
- keep_vars: bool = False,
- size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+ def _model_sharder(
+ model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
+ ) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the state_dict of a model into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
@@ -101,8 +102,10 @@ def _model_sharder(model: nn.Module,
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
- if getattr(model.__class__, "get_extra_state",
- torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
@@ -112,20 +115,20 @@ def _model_sharder(model: nn.Module,
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
- def _optimizer_sharder(optimizer: OptimizerWrapper,
- use_zero: bool,
- dp_group: ProcessGroup,
- tp_group: ProcessGroup,
- master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
- size_per_shard: int = 1024):
-
+ def _optimizer_sharder(
+ optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ size_per_shard: int = 1024,
+ ):
# An internal method that breaks the state_dict of an optimizer into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
+ master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
-
if param is None:
continue
@@ -134,15 +137,17 @@ def _optimizer_sharder(optimizer: OptimizerWrapper,
else:
working_param = param
- param_id = param_info['param2id'][id(working_param)]
- original_shape = param_info['param2shape'][id(working_param)]
- state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
- working_param,
- original_shape=original_shape,
- dp_group=dp_group,
- tp_group=tp_group,
- use_zero=use_zero,
- inplace=False)
+ param_id = param_info["param2id"][id(working_param)]
+ original_shape = param_info["param2shape"][id(working_param)]
+ state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False,
+ )
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
@@ -151,13 +156,15 @@ def _optimizer_sharder(optimizer: OptimizerWrapper,
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False) -> None:
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -176,6 +183,9 @@ def save_sharded_model(self,
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
+
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -192,24 +202,28 @@ def save_sharded_model(self,
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
- control_saving = (self.tp_rank == 0)
+ control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=weights_name,
- is_master=control_saving,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
- if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -226,15 +240,19 @@ def save_sharded_model(self,
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=weights_name,
- is_master=control_saving,
- use_safetensors=use_safetensors,
- use_pp_format=True)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True,
+ )
if control_saving:
- assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -256,12 +274,14 @@ def save_sharded_model(self,
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
- if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {final_index_file_path}.")
-
- def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
@@ -271,6 +291,9 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on the same device might be stored in different files.
"""
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ model_before_wrapping = model # backup for model before wrapping
+ model = model.unwrap()
# Check whether the checkpoint uses safetensors.
use_safetensors = False
@@ -303,11 +326,9 @@ def _load(name: str):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
- load_state_dict_into_model(model,
- state_dict,
- missing_keys=missing_keys,
- strict=strict,
- load_sub_module=True)
+ load_state_dict_into_model(
+ model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
+ )
loaded_file.add(filename)
# Load parameters.
@@ -317,45 +338,33 @@ def _load(name: str):
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
- non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
+ non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
- if getattr(model.__class__, "get_extra_state",
- torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
- with torch.no_grad():
- if self.working_to_master_map is not None:
- for param in model.parameters():
- if (param is None) or (id(param) not in self.working_to_master_map):
- continue
- master_param = self.working_to_master_map[id(param)]
- if self.use_zero:
- # master_param is sharded under Zero setting
- padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
- if padding_size > 0:
- padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
- else:
- padded_param = param.data.view(-1)
- sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
- master_param.data.copy_(sharded_param.data)
- else:
- master_param.data.copy_(param.data)
-
- if self.verbose:
+ model_before_wrapping.update_master_params()
+
+ if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
- def save_sharded_optimizer(self,
- optimizer: OptimizerWrapper,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -372,6 +381,7 @@ def save_sharded_optimizer(self,
prefix (str): Prefix of the files to save
size_per_shard (int): Max file size of each file shard that stores state tensors
"""
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -390,19 +400,21 @@ def save_sharded_optimizer(self,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
- master_to_working_map=self.master_to_working_map,
- size_per_shard=size_per_shard)
+ size_per_shard=size_per_shard,
+ )
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
- control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
+ control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
if control_saving:
# Store param groups.
@@ -412,10 +424,12 @@ def save_sharded_optimizer(self,
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- if self.verbose:
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -431,15 +445,19 @@ def save_sharded_optimizer(self,
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving,
- use_pp_format=True)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
if control_saving:
- assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -449,7 +467,6 @@ def save_sharded_optimizer(self,
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
-
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
@@ -467,10 +484,12 @@ def save_sharded_optimizer(self,
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
- if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {final_index_file_path}.")
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
@@ -481,53 +500,58 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
- def _get_param_id_from_optimizer_param(param: torch.Tensor,
- master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
- return optimizer.param_info['param2id'][id(working_param)]
+ return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by the current pipeline stage to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
+ master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
- for param in pg['params']:
- param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
- weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
- new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
+ new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
- optimizer.optim.__dict__.update({'param_groups': updated_groups})
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
- for param in pg['params']:
+ for param in pg["params"]:
if param is None:
continue
- param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -544,80 +568,226 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor,
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
- if self.master_to_working_map is not None:
- working_param = self.master_to_working_map[id(param)]
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
else:
working_param = param
- original_shape = optimizer.param_info['param2shape'][id(working_param)]
- sharded_state = self.shard_from_complete_optimizer_state(state,
- current_shape=working_param.shape,
- original_shape=original_shape,
- device=device,
- inplace=True)
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+ )
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
- if self.verbose:
+ if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
- def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
- # TODO(Baizhou): support this feature after implementing complete state_dict collection
- raise NotImplementedError
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ """
+ Save model state dict to a single file with given checkpointing path.
- def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
- # TODO(Baizhou): support this feature after implementing complete state_dict collection
- raise NotImplementedError
+ Args:
+ model (ModelWrapper): Model on the local device to be saved.
+ checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
+ gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
- # TODO(Baizhou): support this feature after implementing complete state_dict collection
- raise NotImplementedError
+ if self.dp_rank != 0:
+ return
- def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
- # TODO(Baizhou): support this feature after implementing complete state_dict collection
- raise NotImplementedError
+ # The logic of collecting parameter shards along tp degree
+ # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
+ state_dict = model.state_dict()
- def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ if self.tp_rank == 0:
+ save_state_dict(state_dict, checkpoint, use_safetensors)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ state_dict_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
+
+ # Only the master rank does the saving.
+ if self.coordinator.is_master():
+ complete_state_dict = dict()
+ for _state_dict in state_dict_list:
+ complete_state_dict.update(_state_dict)
+ save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
"""
- Save lr scheduler to checkpoint but only on master process.
+ Load model from a single file with the given path of checkpoint.
+
+ Args:
+ model (ModelWrapper): The model to be loaded.
+ checkpoint (str): Path to the checkpoint file.
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+ This argument should be manually set to False since not all params in the checkpoint are needed for each device when pipeline is enabled.
"""
if self.coordinator.is_master():
- super().save_lr_scheduler(lr_scheduler, checkpoint)
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ strict = False
+ model_before_wrapping = model
+ model = model.unwrap()
+
+ # Load from checkpoint. Since the logic of breaking parameter shards along tp degree
+ # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
+ # model.load_state_dict can be directly called.
+ state_dict = load_state_dict(checkpoint)
+ model.load_state_dict(state_dict, strict=strict)
+
+ # Update master params if mixed-precision training is enabled.
+ model_before_wrapping.update_master_params()
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ """
+ Save optimizer state dict to a file with given path.
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer whose state_dict is to be saved.
+ checkpoint (str): Path to save optimizer state_dict.
+ gather_dtensor (bool): Whether to gather_dtensor, not used.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+
+ # Optimizer states of parameters kept by the local device (i.e. the local pipeline stage)
+ local_states = dict()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ # working param is needed for obtaining correct param_id
+ master_to_working_map = optimizer.get_master_to_working_map()
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ # gather complete state from tp shards & dp shards
+ param_id = optimizer.param_info["param2id"][id(working_param)]
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ use_zero=self.use_zero,
+ inplace=False,
+ device=torch.device("cuda"),
+ )
- def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
- master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ states_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ dist.all_gather_object(states_list, local_states, self.pp_group)
+
+ # Only the master rank does the saving.
+ if self.coordinator.is_master():
+ state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+ for _states in states_list:
+ state_dict["state"].update(_states)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
- Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
- This mapping can only be created when mixied precision is used.
- The created mappings should be mappings from integer parameter addresses to parameter objects.
+ Load optimizer from a file with given path.
Args:
- working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
- master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint (str): Path to the checkpoint file.
"""
- self.working_to_master_map = dict()
- for k, v in working_to_master_map.items():
- if isinstance(k, torch.Tensor):
- self.working_to_master_map[id(k)] = v
- elif isinstance(k, int):
- self.working_to_master_map[k] = v
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
else:
- raise ValueError(
- f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
-
- self.master_to_working_map = dict()
- for k, v in master_to_working_map.items():
- if isinstance(k, torch.Tensor):
- self.master_to_working_map[id(k)] = v
- elif isinstance(k, int):
- self.master_to_working_map[k] = v
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ # Complete optimizer state_dict loaded from the checkpoint; it needs to be processed later.
+ state_dict = load_state_dict(checkpoint)
+
+ # Load param_groups.
+ updated_groups = []
+ saved_groups = state_dict["param_groups"]
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+ master_to_working_map = optimizer.get_master_to_working_map()
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+ load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+ device = param.device
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
else:
- raise ValueError(
- f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
@staticmethod
- def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
- dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
- inplace: bool) -> OrderedDict:
+ def gather_from_sharded_optimizer_state(
+ state: OrderedDict,
+ param: torch.Tensor,
+ original_shape: torch.Size,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ use_zero: bool,
+ inplace: bool,
+ device: torch.device = torch.device("cpu"),
+ ) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -629,6 +799,7 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+ device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
@@ -639,14 +810,13 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != 'step':
-
+ if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
- v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
@@ -655,13 +825,18 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
- state_[k] = v.detach().clone().cpu()
+ state_[k] = v.detach().clone().to(device)
return state_
- def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
- original_shape: torch.Size, device: torch.device,
- inplace: bool) -> OrderedDict:
+ def shard_from_complete_optimizer_state(
+ self,
+ state: OrderedDict,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ ) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
@@ -679,8 +854,7 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape:
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != 'step':
-
+ if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
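
The ZeRO handling in `gather_from_sharded_optimizer_state` and `shard_from_complete_optimizer_state` above reduces to a flatten/pad/split/gather round trip on each state tensor. A single-process sketch of that invariant (the dp-group `all_gather` is simulated with a plain Python list, and the shapes are made up):

```python
# Standalone sketch of the ZeRO shard round trip used above; no process group
# is needed because the dp-group all_gather is simulated with a list.
import torch

dp_size = 4
param = torch.arange(10, dtype=torch.float32).reshape(2, 5)

# Shard: pad the flattened tensor so it splits evenly across dp ranks.
padding = (dp_size - param.numel() % dp_size) % dp_size
flat = torch.nn.functional.pad(param.reshape(-1), [0, padding])
shards = list(flat.split(flat.numel() // dp_size))  # shard i lives on dp rank i

# Gather: stack the shards, drop the padding, restore the original shape.
restored = torch.stack(shards).view(-1)[: param.numel()].reshape_as(param)
assert torch.equal(restored, param)
```
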
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 388cf3fbe9bb..da12c146f2c3 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -6,7 +6,7 @@
from .utils import is_dtensor_checkpoint
-__all__ = ['CheckpointIndexFile']
+__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
@@ -50,7 +50,7 @@ def load(self, json_path: str):
json_path (str): path to the json file.
"""
# load the json file
- with open(json_path, 'r') as f:
+ with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
@@ -75,7 +75,7 @@ def export(self, json_path: str):
index["weight_map"] = self.weight_map
# export the index file
- with open(json_path, 'w') as f:
+ with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
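
For reference, the `.index.json` files that `CheckpointIndexFile` loads and exports follow the Hugging Face sharded-checkpoint convention: a `metadata` block (carrying `total_size`, per `append_meta_data` above) plus a `weight_map` from parameter names to shard files. An illustrative example (the parameter names, shard file names, and size are made up):

```python
# Illustrative shape of an index file handled by CheckpointIndexFile above;
# the parameter names, shard file names, and total_size are assumptions.
import json

index = {
    "metadata": {"total_size": 16777216},
    "weight_map": {
        "embed.weight": "pytorch_model-00001.bin",
        "lm_head.weight": "pytorch_model-00002.bin",
    },
}
print(json.dumps(index, indent=4))  # matches the indent=4 used by export()
```
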
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 3441eca38ce7..d2f4a0bcacf8 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,5 +1,4 @@
# coding=utf-8
-import copy
import os
import re
from collections import abc as container_abcs
@@ -12,8 +11,6 @@
import torch.nn as nn
from torch.optim import Optimizer
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -56,7 +53,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
- import safetensors
return True
except ImportError:
return False
@@ -72,7 +68,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
- if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
+ if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
@@ -88,7 +84,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
- if checkpoint_file_path.endswith('.safetensors'):
+ if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
@@ -114,8 +110,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
partition_dim = dim
break
if partition_dim is not None:
- assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
- f"The parameter isn't evenly distributed among tensor parallel group: \
+ assert (
+ original_shape[partition_dim] == tp_size * current_shape[partition_dim]
+ ), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
@@ -124,28 +121,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# ======================================
# Helper classes and functions for saving shard file
# ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
- '''
- Unwrap a wrapped optimizer.
- This method should be used before saving/loading it to/from sharded checkpoints.
- '''
-
- # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
- unwrapped_optim = optimizer.optim
- if isinstance(unwrapped_optim, ColossalaiOptimizer):
- unwrapped_optim = unwrapped_optim.optim
- return unwrapped_optim
class StateDictSharder:
-
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
@@ -163,13 +147,11 @@ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[Ordere
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
-
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
-
# When state_tensor is not of Tensor class,
# e.g. an SGD optimizer with momentum set to 0 can have None as a state,
# the calculation of tensor size should be skipped to avoid errors.
@@ -221,14 +203,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
return param_
-def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
- checkpoint: str,
- index_file: "CheckpointIndexFile",
- base_filename: str,
- is_master: bool,
- use_safetensors: bool = False,
- use_pp_format: bool = False) -> int:
- '''
+def save_state_dict_shards(
+ sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+ checkpoint: str,
+ index_file: "CheckpointIndexFile",
+ base_filename: str,
+ is_master: bool,
+ use_safetensors: bool = False,
+ use_pp_format: bool = False,
+) -> int:
+ """
Save a sharded state dict only on the master rank; this method can be used for both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
@@ -241,7 +225,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
Returns:
int: the total size of shards
- '''
+ """
total_size = 0
shard_filenames = []
@@ -292,7 +276,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
- states = state_dict['state']
+ states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
@@ -320,9 +304,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
- assert checkpoint_file_path.endswith('.safetensors'), \
- "safetensors only supports .safetensors suffix for checkpoint file."
+ assert checkpoint_file_path.endswith(
+ ".safetensors"
+ ), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
+
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
@@ -340,11 +326,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path)
-def clean_folder(checkpoint_path: str,
- weights_name: str,
- shard_filenames: List[str],
- is_master: bool = True,
- use_pp_format: bool = False):
+def clean_folder(
+ checkpoint_path: str,
+ weights_name: str,
+ shard_filenames: List[str],
+ is_master: bool = True,
+ use_pp_format: bool = False,
+):
"""
Clean up unneeded files in the checkpoint directory after shards of the state_dict have been saved.
@@ -366,8 +354,12 @@ def clean_folder(checkpoint_path: str,
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
- if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
- and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and filename not in shard_filenames
+ and reg.fullmatch(filename_no_suffix) is not None
+ ):
os.remove(full_filename)
@@ -416,7 +408,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
- output_root_path = root_path.joinpath('dtensor')
+ output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
@@ -436,7 +428,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
- ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+ ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
@@ -451,15 +443,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
- return '.safetensors'
+ return ".safetensors"
else:
- return '.bin'
+ return ".bin"
-def generate_checkpoint_shard_file_name(index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None) -> str:
+def generate_checkpoint_shard_file_name(
+ index: int, total_number: int, use_safetensors: bool, prefix: str = None
+) -> str:
"""
Generate checkpoint shard file name.
@@ -493,7 +484,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
- return f'{param_name}.{index}.{suffix}'
+ return f"{param_name}.{index}.{suffix}"
# ========================================
@@ -510,21 +501,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
+
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
- f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
+ f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+ )
return safe_load_file(checkpoint_file)
else:
- return torch.load(checkpoint_file, map_location=torch.device('cpu'))
+ return torch.load(checkpoint_file, map_location=torch.device("cpu"))
-def load_state_dict_into_model(model: nn.Module,
- state_dict: torch.Tensor,
- missing_keys: List,
- strict: bool = False,
- load_sub_module: bool = True):
+def load_state_dict_into_model(
+ model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
+):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
@@ -540,7 +531,7 @@ def load_state_dict_into_model(model: nn.Module,
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, '_metadata', None)
+ metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
@@ -564,10 +555,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True)
if strict:
if len(unexpected_keys) > 0:
- error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
- '"{}"'.format(k) for k in unexpected_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- model.__class__.__name__, "\n\t".join(error_msgs)))
+ error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
+ ", ".join('"{}"'.format(k) for k in unexpected_keys)
+ )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
+ )
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
@@ -577,9 +570,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
- saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
+ saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List):
- raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
+ raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
@@ -588,26 +581,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
- param_lens = (len(g['params']) for g in param_groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+ param_lens = (len(g["params"]) for g in param_groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
+ )
# Creating mapping from id to parameters.
id_map = {
- old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
- )), chain.from_iterable((g['params'] for g in param_groups)))
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in param_groups)),
+ )
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
- optimizer.__dict__.update({'param_groups': updated_groups})
+ optimizer.__dict__.update({"param_groups": updated_groups})
return id_map
@@ -632,7 +629,7 @@ def cast(param, value, key=None):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
- if (key != "step"):
+ if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
@@ -666,8 +663,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
"""
# Do the cleaning up as in src code of Pytorch.
- optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
- optimizer.defaults.setdefault('differentiable', False)
+ optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ optimizer.defaults.setdefault("differentiable", False)
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
@@ -690,20 +687,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
# check that there is only one file ending with .index.json in this directory
- index_files = list(checkpoint_path.glob('*.index.*json'))
+ index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
- assert len(
- index_files
- ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
+ assert (
+ len(index_files) == 1
+ ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
else:
- raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
+ raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
@@ -717,14 +714,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
- assert not is_dtensor_checkpoint(checkpoint_file_path), \
- f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
+ assert not is_dtensor_checkpoint(
+ checkpoint_file_path
+ ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
- assert is_safetensors_available(), \
- f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
+ assert (
+ is_safetensors_available()
+ ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
+
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
@@ -733,7 +733,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
+ return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
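
A hedged sketch (not part of the diff) of how the two helpers reformatted above combine: `has_index_file` decides whether a path points at a sharded checkpoint, and `load_state_dict` reads a single file via safetensors or `torch.load`. The import path is an assumption based on the file this hunk touches.

```python
# Assumption: both helpers live in colossalai.checkpoint_io.utils.
from colossalai.checkpoint_io.utils import has_index_file, load_state_dict

def read_checkpoint(path: str):
    is_sharded, index_file = has_index_file(path)
    if is_sharded:
        # a *.index.json was found: the caller should load the shard files
        # listed in the index instead of a single state dict
        return index_file
    # single-file checkpoint: safetensors or torch.load, chosen by suffix
    return load_state_dict(path)
```
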
diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py
index 658e35e4c72e..c7cb19c19308 100644
--- a/colossalai/cli/__init__.py
+++ b/colossalai/cli/__init__.py
@@ -1,3 +1,3 @@
from .cli import cli
-__all__ = ['cli']
+__all__ = ["cli"]
diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py
deleted file mode 100644
index 618ff8c61dd4..000000000000
--- a/colossalai/cli/benchmark/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import click
-
-from colossalai.context import Config
-
-from .benchmark import run_benchmark
-from .utils import *
-
-__all__ = ['benchmark']
-
-
-@click.command()
-@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
-@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
-@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
-@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
-@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
-@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
-@click.option("-l", "--layers", type=int, default=2)
-@click.option("-m",
- "--model",
- type=click.Choice(['mlp'], case_sensitive=False),
- default='mlp',
- help="Select the model to benchmark, currently only supports MLP")
-def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
- layers: int, model: str):
- args_dict = locals()
- args = Config(args_dict)
- run_benchmark(args)
diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py
deleted file mode 100644
index 97a9f45722dd..000000000000
--- a/colossalai/cli/benchmark/benchmark.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from functools import partial
-from typing import Dict, List
-
-import click
-import torch.multiprocessing as mp
-
-import colossalai
-from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
-from colossalai.context import Config
-from colossalai.context.random import reset_seeds
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.testing import free_port
-from colossalai.utils import MultiTimer
-
-from .models import MLP
-
-
-def run_benchmark(args: Config) -> None:
- """
- Run benchmarking with torch.multiprocessing.
- """
-
- # sanity checks
- if args.gpus is None:
- click.echo("Error: --num_gpus is not given")
- exit()
- if args.gpus <= 1:
- click.echo("Warning: tensor parallel will be activated with at least 2 devices.")
-
- click.echo("=== Benchmarking Parameters ===")
- for k, v in args.items():
- click.echo(f'{k}: {v}')
- click.echo('')
-
- config_list = find_all_configs(args.gpus)
-
- avail_ports = [free_port() for _ in range(len(config_list))]
- run_func = partial(run_dist_profiling,
- world_size=args.gpus,
- port_list=avail_ports,
- config_list=config_list,
- hyperparams=args)
- mp.spawn(run_func, nprocs=args.gpus)
-
-
-def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
- hyperparams: Config) -> None:
- """
- A function executed for profiling, this function should be spawn by torch.multiprocessing.
-
- Args:
- rank (int): rank of the process
- world_size (int): the number of processes
- port_list (List[int]): a list of free ports for initializing distributed networks
- config_list (List[Dict]): a list of configuration
- hyperparams (Config): the hyperparameters given by the user
-
- """
-
- # disable logging for clean output
- disable_existing_loggers()
- logger = get_dist_logger()
- logger.set_level('WARNING')
-
- for config, port in zip(config_list, port_list):
- colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- timer = MultiTimer()
-
- # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
- if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
- click.echo(
- "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
- )
- continue
-
- if hyperparams.model == 'mlp':
- model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
- else:
- if gpc.get_global_rank() == 0:
- click.echo("Error: Invalid argument for --model")
- exit()
-
- data_func = partial(get_batch_data,
- dim=hyperparams.dimension,
- batch_size=hyperparams.batch_size,
- seq_length=hyperparams.seq_len,
- mode=config.parallel.tensor.mode)
-
- fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
- warmup_steps=hyperparams.warmup_steps,
- profile_steps=hyperparams.profile_steps,
- data_func=data_func,
- timer=timer)
-
- gpc.destroy()
- reset_seeds()
-
- if gpc.get_global_rank() == 0:
- config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
- click.echo(f"=== {config_str} ===")
- click.echo(f"Average forward time: {fwd_time}")
- click.echo(f"Average backward time: {bwd_time}")
- click.echo(f"Max allocated GPU memory: {max_allocated}")
- click.echo(f"Max cached GPU memory: {max_cached}\n")
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
deleted file mode 100644
index 385b485b6016..000000000000
--- a/colossalai/cli/benchmark/models.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-
-import colossalai.legacy.nn as col_nn
-
-
-class MLP(torch.nn.Module):
-
- def __init__(self, dim: int, layers: int):
- super().__init__()
- self.layers = torch.nn.ModuleList()
-
- for _ in range(layers):
- self.layers.append(col_nn.Linear(dim, dim))
-
- def forward(self, x):
- for layer in self.layers:
- x = layer(x)
- return x
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
deleted file mode 100644
index ee7d92d6ea6a..000000000000
--- a/colossalai/cli/benchmark/utils.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import math
-import time
-from typing import Callable, Dict, List, Tuple
-
-import torch
-
-from colossalai.context import Config, ParallelMode
-from colossalai.utils import MultiTimer
-
-
-def get_time_stamp() -> int:
- """
- Return the time stamp for profiling.
-
- Returns:
- time_stamp (int): the time given by time.time()
- """
-
- torch.cuda.synchronize()
- time_stamp = time.time()
- return time_stamp
-
-
-def get_memory_states() -> Tuple[float]:
- """
- Return the memory statistics.
-
- Returns:
- max_allocated (float): the allocated CUDA memory
- max_cached (float): the cached CUDA memory
- """
-
- max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
- max_cached = torch.cuda.max_memory_reserved() / (1024**3)
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
- return max_allocated, max_cached
-
-
-def find_all_configs(device_cnt: int) -> List[Dict]:
- """
- Find all possible configurations for tensor parallelism
-
- Args:
- device_cnt (int): the number of devices
-
- Returns:
- config_list (List[Dict]): a list of configurations
- """
-
- def _is_square(num):
- # 2D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(math.sqrt(num))**2 == num
-
- def _is_cube(num):
- # 3D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(num**(1. / 3.))**3 == num
-
- config_list = []
-
- # add non-parallel config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
- config_list.append(config)
-
- # add 1D config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
- config_list.append(config)
-
- # add 2D config only if device_cnt is a square
- if _is_square(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
- config_list.append(config)
-
- # check for 2.5D
- # iterate over depth
- for depth in range(1, device_cnt):
- if device_cnt % depth == 0 and _is_square(device_cnt // depth):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
- config_list.append(config)
-
- # check for 3D if device_cnt is a cube
- if _is_cube(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
- config_list.append(config)
-
- config_list = [Config(cfg) for cfg in config_list]
- return config_list
-
-
-def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
- timer: MultiTimer) -> Tuple[float]:
- """
- Profile the forward and backward of a model
-
- Args:
- model (torch.nn.Module): a PyTorch model
- warmup_steps (int): the number of steps for warmup
- profile_steps (int): the number of steps for profiling
- data_func (Callable): a function to generate random data
- timer (colossalai.utils.Multitimer): a timer instance for time recording
-
- Returns:
- fwd_time (float): the average forward time taken by forward pass in second
- bwd_time (float): the average backward time taken by forward pass in second
- max_allocated (float): the maximum GPU memory allocated in GB
- max_cached (float): the maximum GPU memory cached in GB
- """
-
- def _run_step(data):
- timer.start('forward')
- out = model(data)
- timer.stop('forward', keep_in_history=True)
- timer.start('backward')
- out.mean().backward()
- timer.stop('backward', keep_in_history=True)
-
- data_list = [data_func() for _ in range(warmup_steps)]
- for data in data_list:
- _run_step(data)
- timer.reset('forward')
- timer.reset('backward')
-
- for _ in range(profile_steps):
- data = data_func()
- _run_step(data)
-
- max_allocated, max_cached = get_memory_states()
- fwd_time = timer.get_timer('forward').get_history_mean()
- bwd_time = timer.get_timer('backward').get_history_mean()
- return fwd_time, bwd_time, max_allocated, max_cached
-
-
-def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
- """
- Return a random data of shape (batch_size, seq_length, dim) for profiling.
-
- Args:
- dim (int): hidden size
- batch_size (int): the number of data samples
- seq_length (int): the number of tokens
- mode (ParallelMode): Colossal-AI ParallelMode enum
-
- Returns:
- data (torch.Tensor): random data
- """
-
- if mode in ['2d', '2.5d']:
- batch_size = batch_size // 2
- dim = dim // 2
- elif mode == '3d':
- batch_size = batch_size // 4
- dim = dim // 2
-
- data = torch.rand(batch_size, seq_length, dim).cuda()
- return data
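
The entire `colossalai benchmark` command is deleted above. For reference, a hedged re-sketch in plain Python of the tensor-parallel mode enumeration that the removed `find_all_configs` performed (no Colossal-AI imports needed):

```python
import math

def _is_square(n: int) -> bool:
    return n > 1 and math.isqrt(n) ** 2 == n   # 2D parallel needs >= 2 devices

def _is_cube(n: int) -> bool:
    return n > 1 and round(n ** (1 / 3)) ** 3 == n

def modes_for(device_cnt: int):
    modes = [None, "1d"]                       # non-parallel and 1D are always tried
    if _is_square(device_cnt):
        modes.append("2d")                     # 2D mesh needs a square device count
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            modes.append(("2.5d", depth))      # 2.5D: depth stacked square meshes
    if _is_cube(device_cnt):
        modes.append("3d")                     # 3D mesh needs a cubic device count
    return modes

print(modes_for(8))  # [None, '1d', ('2.5d', 2), '3d']
print(modes_for(4))  # [None, '1d', '2d', ('2.5d', 1)]
```
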
diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py
index a86b32bb6a18..7c26ab6ade6c 100644
--- a/colossalai/cli/check/__init__.py
+++ b/colossalai/cli/check/__init__.py
@@ -1,11 +1,12 @@
import click
+
from .check_installation import check_installation
-__all__ = ['check']
+__all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option")
-@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly")
+@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation):
if installation:
check_installation()
diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py
index 4a481f3bd122..772c513ffa06 100644
--- a/colossalai/cli/check/check_installation.py
+++ b/colossalai/cli/check/check_installation.py
@@ -9,7 +9,7 @@
def to_click_output(val):
# convert installation check output to understandable symbols for readability
- VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
+ VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val]
@@ -55,8 +55,8 @@ def check_installation():
else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
- click.echo(f'#### Installation Report ####')
- click.echo(f'\n------------ Environment ------------')
+ click.echo(f"#### Installation Report ####")
+ click.echo(f"\n------------ Environment ------------")
click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
@@ -69,7 +69,7 @@ def check_installation():
f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
)
- click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
+ click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------")
- click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
+ click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"")
@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False
# split version into [major, minor, patch]
- versions = [version.split('.') for version in versions]
+ versions = [version.split(".") for version in versions]
for version in versions:
if len(version) == 2:
# x means unknown
- version.append('x')
+ version.append("x")
for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1
@@ -137,11 +137,11 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer.
- colossalai_version = colossalai.__version__.split('+')[0]
+ colossalai_version = colossalai.__version__.split("+")[0]
try:
- torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
- cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
+ torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
+ cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except:
torch_version_for_aot_build = None
cuda_version_for_aot_build = None
@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime.
"""
try:
- import colossalai._C.fused_optim
found_aot_cuda_ext = True
except ImportError:
found_aot_cuda_ext = False
@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats
# - 1.13.1+cu113
# - 1.13.1.devxxx
- torch_version = torch.__version__.split('+')[0]
- torch_version = '.'.join(torch_version.split('.')[:3])
+ torch_version = torch.__version__.split("+")[0]
+ torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
- torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
+ torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except:
torch_cuda_version = None
@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
- cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
+ cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except:
cuda_version = None
return cuda_version
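
The string surgery above is easier to follow with concrete inputs; a hedged illustration of what `_parse_colossalai_version` and `_check_torch_version` extract (the version strings are illustrative):

```python
# AOT-built wheel version format: X.X.X+torchX.XXcuXX.X
v = "0.3.2+torch1.13cu11.7"
print(v.split("+")[0])                           # 0.3.2  (base Colossal-AI version)
print(v.split("torch")[1].split("cu")[0])        # 1.13   (torch used for AOT build)
print(v.split("cu")[1])                          # 11.7   (CUDA used for AOT build)

t = "1.13.1+cu117"                               # a typical torch.__version__
print(".".join(t.split("+")[0].split(".")[:3]))  # 1.13.1
```
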
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index a94e1150e49f..0d94fe59f8ae 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,12 +1,10 @@
import click
-from .benchmark import benchmark
from .check import check
from .launcher import run
-class Arguments():
-
+class Arguments:
def __init__(self, arg_dict):
for k, v in arg_dict.items():
self.__dict__[k] = v
@@ -19,7 +17,6 @@ def cli():
cli.add_command(run)
cli.add_command(check)
-cli.add_command(benchmark)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli()
diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py
index 808e4e84574f..0f9ead6495db 100644
--- a/colossalai/cli/launcher/__init__.py
+++ b/colossalai/cli/launcher/__init__.py
@@ -5,56 +5,81 @@
from .run import launch_multi_processes
-@click.command(help="Launch distributed training on a single node or multiple nodes",
- context_settings=dict(ignore_unknown_options=True))
-@click.option("-H",
- "-host",
- "--host",
- type=str,
- default=None,
- help="the list of hostnames to launch in the format ,")
+@click.command(
+ help="Launch distributed training on a single node or multiple nodes",
+ context_settings=dict(ignore_unknown_options=True),
+)
+@click.option(
+ "-H",
+ "-host",
+ "--host",
+ type=str,
+ default=None,
+ help="the list of hostnames to launch in the format ,",
+)
@click.option(
"--hostfile",
type=str,
default=None,
- help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
-@click.option("--include",
- type=str,
- default=None,
- help="Specify computing devices to use during execution. String format is ,,"
- " only effective when used with --hostfile.")
+ help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
+)
+@click.option(
+ "--include",
+ type=str,
+ default=None,
+ help="Specify computing devices to use during execution. String format is ,,"
+ " only effective when used with --hostfile.",
+)
@click.option(
"--exclude",
type=str,
default=None,
- help=
- "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
- " only effective when used with --hostfile.")
-@click.option("--num_nodes",
- type=int,
- default=-1,
- help="Total number of worker nodes to use, only effective when used with --hostfile.")
+ help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
+ " only effective when used with --hostfile.",
+)
+@click.option(
+ "--num_nodes",
+ type=int,
+ default=-1,
+ help="Total number of worker nodes to use, only effective when used with --hostfile.",
+)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
-@click.option("--master_port",
- type=int,
- default=29500,
- help="(optional) Port used by PyTorch distributed for communication during distributed training.")
-@click.option("--master_addr",
- type=str,
- default="127.0.0.1",
- help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
+@click.option(
+ "--master_port",
+ type=int,
+ default=29500,
+ help="(optional) Port used by PyTorch distributed for communication during distributed training.",
+)
+@click.option(
+ "--master_addr",
+ type=str,
+ default="127.0.0.1",
+ help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
+)
@click.option(
"--extra_launch_args",
type=str,
default=None,
- help=
- "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
- "This will be converted to --arg1=1 --arg2=2 during execution")
+ help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
+ "This will be converted to --arg1=1 --arg2=2 during execution",
+)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
-@click.argument('user_args', nargs=-1)
-def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
- master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
+@click.argument("user_args", nargs=-1)
+def run(
+ host: str,
+ hostfile: str,
+ num_nodes: int,
+ nproc_per_node: int,
+ include: str,
+ exclude: str,
+ master_addr: str,
+ master_port: int,
+ extra_launch_args: str,
+ ssh_port: int,
+ user_script: str,
+ user_args: str,
+) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected
colossalai run --hostfile <hostfile> --master_addr host1 --exclude host2 --nproc_per_node 4 train.py
"""
- if not user_script.endswith('.py'):
- click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
+ if not user_script.endswith(".py"):
+ click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
args_dict = locals()
diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py
index 2a6a111e4d72..684f64f59d28 100644
--- a/colossalai/cli/launcher/hostinfo.py
+++ b/colossalai/cli/launcher/hostinfo.py
@@ -1,5 +1,4 @@
import socket
-from typing import List
class HostInfo:
@@ -34,7 +33,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
"""
if port is None:
- port = 22 # no port specified, lets just use the ssh port
+ port = 22  # no port specified, let's just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
@@ -50,7 +49,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
return localaddrs == targetaddrs
def __str__(self):
- return f'hostname: {self.hostname}, port: {self.port}'
+ return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self):
return self.__str__()
diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py
index 85b241e96292..99c4db406844 100644
--- a/colossalai/cli/launcher/multinode_runner.py
+++ b/colossalai/cli/launcher/multinode_runner.py
@@ -7,8 +7,13 @@
from .hostinfo import HostInfo, HostInfoList
-def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
- send_conn: mp_connection.Connection, env: dict) -> None:
+def run_on_host(
+ hostinfo: HostInfo,
+ workdir: str,
+ recv_conn: mp_connection.Connection,
+ send_conn: mp_connection.Connection,
+ env: dict,
+) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
- env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+ env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
- if cmds == 'exit':
+ if cmds == "exit":
# exit from the loop
finish = True
break
@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
- send_conn.send('success')
+ send_conn.send("success")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
- send_conn.send('failure')
+ send_conn.send("failure")
# shutdown
send_conn.send("finish")
@@ -96,8 +101,7 @@ def send(self, hostinfo: HostInfo, cmd: str) -> None:
cmd (str): the command to execute
"""
- assert hostinfo.hostname in self.master_send_conns, \
- f'{hostinfo} is not found in the current connections'
+ assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
@@ -107,7 +111,7 @@ def stop_all(self) -> None:
"""
for hostname, conn in self.master_send_conns.items():
- conn.send('exit')
+ conn.send("exit")
def recv_from_all(self) -> dict:
"""
diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py
index d2d02811ac9d..88f70f02ec27 100644
--- a/colossalai/cli/launcher/run.py
+++ b/colossalai/cli/launcher/run.py
@@ -12,7 +12,7 @@
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
-NODE_SEP = ','
+NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
- with open(hostfile_path, 'r') as fd:
+ with open(hostfile_path, "r") as fd:
device_pool = HostInfoList()
for line in fd.readlines():
line = line.strip()
- if line == '':
+ if line == "":
# skip empty lines
continue
@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
- '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
+ """Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
- '''
+ """
# Ensure include/exclude are mutually exclusive
if include_str and exclude_str:
@@ -136,16 +136,16 @@ def _arg_dict_to_list(arg_dict):
for k, v in arg_dict.items():
if v:
- ret.append(f'--{k}={v}')
+ ret.append(f"--{k}={v}")
else:
- ret.append(f'--{k}')
+ ret.append(f"--{k}")
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
- for arg in extra_launch_args.split(','):
- if '=' in arg:
- k, v = arg.split('=')
+ for arg in extra_launch_args.split(","):
+ if "=" in arg:
+ k, v = arg.split("=")
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
@@ -156,11 +156,17 @@ def _arg_dict_to_list(arg_dict):
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
- if torch_version.minor < 9:
+ if torch_version.major == 1 and torch_version.minor < 9:
+ # torch distributed launch cmd with torch < 1.9
cmd = [
- sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
- f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
- f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.launch",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--master_addr={master_addr}",
+ f"--master_port={master_port}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
@@ -172,19 +178,28 @@ def _arg_dict_to_list(arg_dict):
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
- if torch_version.minor < 10:
+ if torch_version.major == 1 and torch_version.minor == 9:
+ # torch distributed launch cmd with torch == 1.9
cmd = [
- sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
- f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
+ # torch distributed launch cmd with torch > 1.9
cmd = [
- "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ "torchrun",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
- cmd = ' '.join(cmd)
+ cmd = " ".join(cmd)
return cmd
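
For a modern PyTorch (> 1.9) the branch above assembles a `torchrun` invocation roughly like the one below. The contents of `default_torchrun_rdzv_args` are not visible in this hunk, so the two rdzv flags are an assumption; everything else mirrors the code:

```python
# Hedged example of the command string returned by get_launch_command.
cmd = " ".join(
    [
        "torchrun",
        "--nproc_per_node=8",
        "--nnodes=2",
        "--node_rank=0",
        "--rdzv_backend=c10d",          # assumed default rdzv argument
        "--rdzv_endpoint=host1:29500",  # assumed: master_addr:master_port
        "train.py",
        "--config=cfg.py",
    ]
)
print(cmd)
```
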
@@ -248,18 +263,18 @@ def launch_multi_processes(args: Config) -> None:
# run on local node if not hosts or hostfile is given
# add local node to host info list
active_device_pool = HostInfoList()
- localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+ localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info)
# launch distributed processes
runner = MultiNodeRunner()
- curr_path = os.path.abspath('.')
+ curr_path = os.path.abspath(".")
# collect current path env
env = dict()
for k, v in os.environ.items():
# do not support multi-line env var
- if v and '\n' not in v:
+ if v and "\n" not in v:
env[k] = v
# establish remote connection
@@ -271,14 +286,16 @@ def launch_multi_processes(args: Config) -> None:
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
- cmd = get_launch_command(master_addr=args.master_addr,
- master_port=args.master_port,
- nproc_per_node=args.nproc_per_node,
- user_script=args.user_script,
- user_args=args.user_args,
- node_rank=node_id,
- num_nodes=len(active_device_pool),
- extra_launch_args=args.extra_launch_args)
+ cmd = get_launch_command(
+ master_addr=args.master_addr,
+ master_port=args.master_port,
+ nproc_per_node=args.nproc_per_node,
+ user_script=args.user_script,
+ user_args=args.user_args,
+ node_rank=node_id,
+ num_nodes=len(active_device_pool),
+ extra_launch_args=args.extra_launch_args,
+ )
runner.send(hostinfo=hostinfo, cmd=cmd)
# start training
diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
index 44f571ca2501..b8176feb647b 100644
--- a/colossalai/cluster/__init__.py
+++ b/colossalai/cluster/__init__.py
@@ -3,4 +3,4 @@
from .process_group_manager import ProcessGroupManager
from .process_group_mesh import ProcessGroupMesh
-__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
+__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
index 8754baa19792..e35aca5f4d7e 100644
--- a/colossalai/cluster/device_mesh_manager.py
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -10,13 +10,14 @@
@dataclass
class DeviceMeshInfo:
- '''
+ """
This class is used to store the information used to initialize the device mesh.
Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
- '''
+ """
+
physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
@@ -24,16 +25,18 @@ def __post_init__(self):
if self.mesh_shape is not None:
world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
- assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
+ assert (
+ world_size == mesh_shape_numel
+ ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
- '''
+ """
This method is used to initialize the device mesh.
Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
- '''
+ """
# parse the device mesh info
physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices)
@@ -67,13 +70,13 @@ def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMe
Args:
name (str): name of the device mesh
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
- """
+ """
if name not in self.device_mesh_store:
device_mesh = initialize_device_mesh(device_mesh_info)
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
- raise ValueError(f'Device mesh {name} already exists.')
+ raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh:
"""
@@ -88,7 +91,7 @@ def get(self, name: str) -> DeviceMesh:
if name in self.device_mesh_store:
return self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -103,7 +106,7 @@ def destroy(self, name: str) -> None:
dist.destroy_process_group(pg)
del self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self):
"""
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
index 3ee364ec3364..98191747e5b3 100644
--- a/colossalai/cluster/dist_coordinator.py
+++ b/colossalai/cluster/dist_coordinator.py
@@ -20,14 +20,16 @@ class in the whole program.
- master: the process with rank 0
- node master: the process with local rank 0 on the current node
- Example:
- >>> from colossalai.cluster.dist_coordinator import DistCoordinator
- >>> coordinator = DistCoordinator()
- >>>
- >>> if coordinator.is_master():
- >>> do_something()
- >>>
- >>> coordinator.print_on_master('hello world')
+
+ ```python
+ from colossalai.cluster.dist_coordinator import DistCoordinator
+ coordinator = DistCoordinator()
+
+ if coordinator.is_master():
+ do_something()
+
+ coordinator.print_on_master('hello world')
+ ```
Attributes:
rank (int): the rank of the current process
@@ -36,12 +38,13 @@ class in the whole program.
"""
def __init__(self):
- assert dist.is_initialized(
- ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+ assert (
+ dist.is_initialized()
+ ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
- self._local_rank = os.environ.get('LOCAL_RANK', -1)
+ self._local_rank = os.environ.get("LOCAL_RANK", -1)
@property
def rank(self) -> int:
@@ -59,7 +62,9 @@ def _assert_local_rank_set(self):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
- assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+ assert (
+ self.local_rank >= 0
+ ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
@@ -128,11 +133,13 @@ def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
- Example:
- >>> from colossalai.cluster import DistCoordinator
- >>> dist_coordinator = DistCoordinator()
- >>> with dist_coordinator.priority_execution():
- >>> dataset = CIFAR10(root='./data', download=True)
+
+ ```python
+ from colossalai.cluster import DistCoordinator
+ dist_coordinator = DistCoordinator()
+ with dist_coordinator.priority_execution():
+ dataset = CIFAR10(root='./data', download=True)
+ ```
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -171,19 +178,19 @@ def on_master_only(self, process_group: ProcessGroup = None):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).
- Example:
- >>> from colossalai.cluster import DistCoordinator
- >>> dist_coordinator = DistCoordinator()
- >>>
- >>> @dist_coordinator.on_master_only()
- >>> def print_on_master(msg):
- >>> print(msg)
+ ```python
+ from colossalai.cluster import DistCoordinator
+ dist_coordinator = DistCoordinator()
+
+ @dist_coordinator.on_master_only()
+ def print_on_master(msg):
+ print(msg)
+ ```
"""
is_master = self.is_master(process_group)
# define an inner function
def decorator(func):
-
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
index e52661846f3e..68106b503126 100644
--- a/colossalai/cluster/process_group_manager.py
+++ b/colossalai/cluster/process_group_manager.py
@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self):
self.pg_store = dict()
- def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+ def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.
@@ -36,7 +36,7 @@ def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl
self.pg_store[name] = pg
return pg
else:
- raise ValueError(f'Process group {name} already exists.')
+ raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup:
"""
@@ -51,7 +51,7 @@ def get(self, name: str) -> ProcessGroup:
if name in self.pg_store:
return self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -64,7 +64,7 @@ def destroy(self, name: str) -> None:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None:
"""
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 623160003767..3885bc962561 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -94,7 +94,7 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
return np.unravel_index(rank, shape)
@staticmethod
- def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
+ def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, an index out of range is wrapped around.
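
`ravel` and `unravel` are thin wrappers over numpy's index conversion; a quick worked example of the rank/coordinate mapping, including the 'wrap' mode:

```python
import numpy as np

shape = (2, 4)                                    # a 2 x 4 process mesh
print(np.unravel_index(6, shape))                 # rank 6 -> coordinate (1, 2)
print(int(np.ravel_multi_index((1, 2), shape)))   # coordinate (1, 2) -> rank 6
print(int(np.ravel_multi_index((1, 4), shape, mode="wrap")))  # wraps to (1, 0) -> 4
```
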
@@ -141,8 +141,9 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
return list(self._group_to_ranks[group])
@staticmethod
- def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
- indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
+ def get_coords_along_axis(
+ base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+ ) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis.
Args:
@@ -155,13 +156,12 @@ def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
"""
coords_in_group = []
for idx in indices_at_axis:
- coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
+ coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group
- def create_group_along_axis(self,
- axis: int,
- indices_at_axis: Optional[List[int]] = None,
- backend: Optional[str] = None) -> ProcessGroup:
+ def create_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
@@ -186,10 +186,9 @@ def create_group_along_axis(self,
target_group = group
return target_group
- def get_group_along_axis(self,
- axis: int,
- indices_at_axis: Optional[List[int]] = None,
- backend: Optional[str] = None) -> ProcessGroup:
+ def get_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
diff --git a/colossalai/constants.py b/colossalai/constants.py
deleted file mode 100644
index 6cf9085f9fbb..000000000000
--- a/colossalai/constants.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
-TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
-
-# initializer
-INITIALIZER_MAPPING = {
- 'data': 'Initializer_Data',
- 'tensor': 'Initializer_Tensor',
- 'pipeline': 'Initializer_Pipeline',
- 'embedding': 'Initializer_Embedding',
- '1d': 'Initializer_1D',
- '2d': 'Initializer_2D',
- '2.5d': 'Initializer_2p5D',
- '3d': 'Initializer_3D',
- 'sequence': 'Initializer_Sequence',
- 'model': 'Initializer_Model',
- 'moe': 'Initializer_Moe'
-}
-
-# 3D parallelism groups
-INPUT_GROUP_3D = 'input_group_3d'
-WEIGHT_GROUP_3D = 'weight_group_3d'
-OUTPUT_GROUP_3D = 'output_group_3d'
-INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
-OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'
-
-# Attributes of tensor parallel parameters
-IS_TENSOR_PARALLEL = 'is_tensor_parallel'
-NUM_PARTITIONS = 'num_partitions'
-TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index 50178b5fa850..ab57301bb910 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,6 +1,8 @@
from .config import Config, ConfigException
-from .parallel_context import ParallelContext
-from .parallel_mode import ParallelMode
-from .moe_context import MOE_CONTEXT
-from .process_group_initializer import *
-from .random import *
+
+# from .moe_context import MOE_CONTEXT
+
+__all__ = [
+ "Config",
+ "ConfigException",
+]
diff --git a/colossalai/context/config.py b/colossalai/context/config.py
index 8903707708df..05a2e4bf044a 100644
--- a/colossalai/context/config.py
+++ b/colossalai/context/config.py
@@ -5,6 +5,7 @@
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
+
from colossalai.logging import get_dist_logger
@@ -41,7 +42,7 @@ def _add_item(self, key, value):
self.__setattr__(key, value)
def update(self, config):
- assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
+ assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items():
self._add_item(k, v)
return self
@@ -66,11 +67,11 @@ def from_file(filename: str):
elif isinstance(filename, Path):
filepath = filename.absolute()
- assert filepath.exists(), f'{filename} is not found, please check your configuration path'
+ assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension
extension = filepath.suffix
- assert extension == '.py', 'only .py files are supported'
+ assert extension == ".py", "only .py files are supported"
# import the config as module
remove_path = False
@@ -86,13 +87,13 @@ def from_file(filename: str):
config = Config()
for k, v in module.__dict__.items():
- if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
+ if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
- logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
+ logger.debug("variables which starts with __, is a module or class declaration are omitted in config file")
# remove module
del sys.modules[module_name]
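
A hedged usage sketch of `Config.from_file` above: a config is a plain `.py` file whose top-level variables become attributes, while dunder names, modules, and classes are skipped (as the debug message notes). The file name is illustrative:

```python
# --- my_config.py (illustrative) ---
# batch_size = 64
# parallel = dict(tensor=dict(size=4, mode="1d"))

from colossalai.context import Config

cfg = Config.from_file("my_config.py")
print(cfg.batch_size)             # 64
cfg.update({"batch_size": 128})   # merging a plain dict (or another Config) is allowed
```
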
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index b41f4072a405..066dfc7222e1 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -3,21 +3,19 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.core import global_context as gpc
+
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
- raise NotImplementedError("Moe is not compatible with tensor or "
- "pipeline parallel at present.")
+ raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
- """Moe parallelism information, storing parallel sizes and groups.
- """
+ """Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
@@ -61,10 +59,12 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
self.world_size = dist.get_world_size()
- from colossalai.core import global_context as gpc
- self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
- assert self.world_size % self.max_ep_size == 0, \
- "Maximum expert parallel size must be a factor of the number of GPUs"
+ from colossalai.legacy.core import global_context as gpc
+
+ self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
+ assert (
+ self.world_size % self.max_ep_size == 0
+ ), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
@@ -72,6 +72,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
+
moe_set_seed(seed)
self.has_setup = True
@@ -89,11 +90,13 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
number of local experts, the MoeParallelInfo of the current ep_size
"""
- gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
- lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
+ gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
+ lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
- assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
- " is not a multiple of ep size or vice versa."
+ assert gt_flag or lt_flag, (
+ "Automatic experts placement dose not not support expert number"
+ " is not a multiple of ep size or vice versa."
+ )
# If the number of experts is greater than the maximum expert parallel size (a.k.a. ep_size),
# there are multiple experts in each GPU and each GPU has different experts
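
The gt/lt flags above encode a divisibility rule between the expert count and the expert-parallel size; a small worked example (a `max_ep_size` of 4 is illustrative):

```python
num_experts, max_ep_size = 8, 4
gt_flag = num_experts % max_ep_size == 0   # True: 2 experts per expert-parallel rank
lt_flag = max_ep_size % num_experts == 0   # False
assert gt_flag or lt_flag                  # e.g. 3 experts on 4 ranks would fail here
```
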
diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py
index 8ca335119d52..3088b0dffaac 100644
--- a/colossalai/context/singleton_meta.py
+++ b/colossalai/context/singleton_meta.py
@@ -16,6 +16,7 @@ def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
- assert len(args) == 0 and len(
- kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
+ assert (
+ len(args) == 0 and len(kwargs) == 0
+ ), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls]
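
A hedged sketch of the singleton guard above; `MyContext` is a hypothetical user of the metaclass:

```python
from colossalai.context.singleton_meta import SingletonMeta

class MyContext(metaclass=SingletonMeta):   # hypothetical example class
    def __init__(self, seed: int = 0):
        self.seed = seed

a = MyContext(42)
b = MyContext()          # returns the existing instance; __init__ is not re-run
assert a is b and b.seed == 42
try:
    MyContext(7)         # non-empty args after creation trip the assertion above
except AssertionError as err:
    print(err)
```
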
diff --git a/colossalai/core.py b/colossalai/core.py
deleted file mode 100644
index 153247bbed9c..000000000000
--- a/colossalai/core.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from colossalai.context.parallel_context import global_context
-
-__all__ = ['global_context']
\ No newline at end of file
diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py
index 689189998c3f..34a7d2526fda 100644
--- a/colossalai/device/__init__.py
+++ b/colossalai/device/__init__.py
@@ -1,4 +1,4 @@
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
-__all__ = ['AlphaBetaProfiler', 'alpa_dp']
+__all__ = ["AlphaBetaProfiler", "alpa_dp"]
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index f4e6cfffbcdf..88520b2a14d0 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -13,7 +13,7 @@
class AlphaBetaProfiler:
- '''
+ """
Profile alpha and beta value for a given device list.
Usage:
@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
- '''
-
- def __init__(self,
- physical_devices: List[int],
- alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
- ctype: str = 'a',
- warmup: int = 5,
- repeat: int = 25,
- latency_iters: int = 5,
- homogeneous_tolerance: float = 0.1):
- '''
+ """
+
+ def __init__(
+ self,
+ physical_devices: List[int],
+ alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
+ ctype: str = "a",
+ warmup: int = 5,
+ repeat: int = 25,
+ latency_iters: int = 5,
+ homogeneous_tolerance: float = 0.1,
+ ):
+ """
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
@@ -45,7 +47,7 @@ def __init__(self,
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
- '''
+ """
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
@@ -123,7 +125,7 @@ def _profile(self, process_group, pg_handler, nbytes):
return (None, None)
def profile_latency(self, process_group, pg_handler):
- '''
+ """
This function is used to profile the latency of the given process group with a series of bytes.
Args:
@@ -132,7 +134,7 @@ def profile_latency(self, process_group, pg_handler):
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
- '''
+ """
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
@@ -148,26 +150,26 @@ def profile_latency(self, process_group, pg_handler):
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
- '''
+ """
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
- '''
+ """
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
- '''
+ """
This method is used to profile the alpha and beta values for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
- '''
+ """
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
- global_pg_handler = dist.new_group(self.physical_devices)
+ dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
@@ -208,7 +210,7 @@ def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
return alpha_beta_dict
def search_best_logical_mesh(self):
- '''
+ """
This method is used to search for the best logical mesh for the given device list.
The best logical mesh is searched for in the following steps:
@@ -232,19 +234,19 @@ def search_best_logical_mesh(self):
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
- '''
+ """
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
- '''
+ """
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if their beta values
are in the range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
- '''
+ """
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
@@ -254,7 +256,8 @@ def _detect_homogeneous_device(alpha_beta_dict):
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
- 1 - self.homogeneous_tolerance):
+ 1 - self.homogeneous_tolerance
+ ):
match_beta = beta_value
break
@@ -267,9 +270,9 @@ def _detect_homogeneous_device(alpha_beta_dict):
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to check whether the homogeneous_group contains all physical devices.
- '''
+ """
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
@@ -277,9 +280,9 @@ def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to construct the largest ring in the homogeneous_group for each rank.
- '''
+ """
# Construct the ring
ring = []
ranks_in_ring = []
@@ -300,7 +303,9 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
- rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+ rank_to_append = (
+ process_group[0] if process_group[1] == check_rank else process_group[1]
+ )
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
@@ -314,7 +319,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
- balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+ balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
@@ -348,7 +353,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
- '''
+ """
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
@@ -360,7 +365,7 @@ def extract_alpha_beta_for_device_mesh(self):
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
- '''
+ """
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py
index 4ab72dfe60f0..72d432701ada 100644
--- a/colossalai/device/calc_pipeline_strategy.py
+++ b/colossalai/device/calc_pipeline_strategy.py
@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host:
i *= 2
p += 1
- assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
- f"while now num_devices_per_host = {num_devices_per_host}")
+ assert pow(2, p) == num_devices_per_host, (
+ "Only supports the cases where num_devices_per_host is power of two, "
+ f"while now num_devices_per_host = {num_devices_per_host}"
+ )
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices
-def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
- best_configs):
+def alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
+):
"""Implementation of Alpa DP for pipeline strategy
- Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
+ Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
- Arguments:
- num_layers: K
- num_devices: N*M
- num_microbatches: B
- submesh_choices: List[(n_i,m_i)]
- compute_cost: t_intra
- """
+ Arguments:
+ num_layers: K
+ num_devices: N*M
+ num_microbatches: B
+ submesh_choices: List[(n_i,m_i)]
+ compute_cost: t_intra
+ """
# For f, layer IDs start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
- if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
+ if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
- next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
+ next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
- assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
+ assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res
-def alpa_dp(num_layers,
- num_devices,
- num_microbatches,
- submesh_choices,
- num_autosharding_configs,
- compute_cost,
- gap=1e-6):
+def alpa_dp(
+ num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
+):
"""Alpa auto stage dynamic programming.
- Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
+ Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
- assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
- num_autosharding_configs), "Cost shape wrong."
+ assert np.shape(compute_cost) == (
+ num_layers,
+ num_layers,
+ len(submesh_choices),
+ num_autosharding_configs,
+ ), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break
if max_stage_cost - last_max_stage_cost < gap:
continue
- cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
- max_stage_cost, best_configs)
+ cost, solution = alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
+ )
if cost < best_cost:
best_cost = cost
best_solution = solution
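For orientation, a toy invocation of `alpa_dp` might look like the sketch below; the sizes and random costs are placeholders, so the returned partition is only meaningful as a shape/flow check, not as a real pipeline plan.

```python
import numpy as np

from colossalai.device.calc_pipeline_strategy import alpa_dp

# Hypothetical toy problem: 4 layers on 4 devices with 8 microbatches.
num_layers, num_devices, num_microbatches = 4, 4, 8
submesh_choices = [(1, 1), (1, 2), (1, 4)]
num_autosharding_configs = 2
# compute_cost[k, i, m, c]: intra-stage cost of layers [k, i) on submesh m with
# autosharding config c, matching the shape assertion inside alpa_dp.
compute_cost = np.random.rand(num_layers, num_layers, len(submesh_choices), num_autosharding_configs)

total_cost, stages = alpa_dp(
    num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost
)
```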
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index f41af1161be1..72f199203a9d 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -40,14 +40,16 @@ class DeviceMesh:
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
- def __init__(self,
- physical_mesh_id: torch.Tensor,
- mesh_shape: torch.Size = None,
- logical_mesh_id: torch.Tensor = None,
- mesh_alpha: List[float] = None,
- mesh_beta: List[float] = None,
- init_process_group: bool = False,
- device: str = 'cuda'):
+ def __init__(
+ self,
+ physical_mesh_id: torch.Tensor,
+ mesh_shape: torch.Size = None,
+ logical_mesh_id: torch.Tensor = None,
+ mesh_alpha: List[float] = None,
+ mesh_beta: List[float] = None,
+ init_process_group: bool = False,
+ device: str = "cuda",
+ ):
# ============================
# Physical & Logical Mesh IDs
# ============================
@@ -57,9 +59,10 @@ def __init__(self,
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
- assert mesh_shape is None or logical_mesh_id is None, \
- "Only one of mesh_shape and logical_mesh_id can be specified." \
+ assert mesh_shape is None or logical_mesh_id is None, (
+ "Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
+ )
if logical_mesh_id is None:
self._mesh_shape = mesh_shape
@@ -71,12 +74,15 @@ def __init__(self,
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
- assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
- "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
- assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
- "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
- assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
- "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
+ assert torch.equal(
+ torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
+ ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
+ assert (
+ torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
+ ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
+ assert (
+ torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
+ ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model
@@ -92,8 +98,9 @@ def __init__(self,
self.mesh_beta = tuple(mesh_beta)
# ensure the alpha and beta have the same shape
- assert len(self.mesh_alpha) == len(self.mesh_beta), \
- "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
+ assert len(self.mesh_alpha) == len(
+ self.mesh_beta
+ ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
@@ -109,8 +116,9 @@ def __init__(self,
# <global-rank>: [<local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
- self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
- tensor=self.logical_mesh_id)
+ self._init_global_to_logical_rank_mapping(
+ mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
+ )
# create process group
self._process_group_dict = {}
@@ -194,8 +202,9 @@ def _get_device_by_backend(process_group):
device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same
- assert all([device == device_list[0] for device in device_list]), \
- "All devices should be the same, please check your input process groups are created with the same distributed backend."
+ assert all(
+ [device == device_list[0] for device in device_list]
+ ), "All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id
# as we only get the process group associated with the current process,
@@ -270,7 +279,7 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
- if k != '_process_group_dict':
+ if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
# process group cannot be copied
@@ -278,10 +287,9 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
setattr(result, k, v)
return result
- def _init_global_to_logical_rank_mapping(self,
- mapping: Dict,
- tensor: torch.Tensor,
- index_list: List[int] = []) -> Dict[int, List[int]]:
+ def _init_global_to_logical_rank_mapping(
+ self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
+ ) -> Dict[int, List[int]]:
"""
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
@@ -311,15 +319,19 @@ def _init_global_to_logical_rank_mapping(self,
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self):
- '''
+ """
This method is used to initialize the logical process groups which will be used for communication
across the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions such as ShapeConsistencyManager.apply will raise errors.
- '''
+ """
# sanity check
- assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
- assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
+ assert (
+ dist.is_initialized()
+ ), "The torch.distributed should be initialized before calling init_logical_process_group"
+ assert (
+ not self._is_initialized
+ ), "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
@@ -389,7 +401,7 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[i
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
- '''
+ """
Given a global rank, return all global ranks involved in its associated process group along each axis.
Example:
@@ -414,7 +426,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
- '''
+ """
# We have already built the global-rank-to-local-rank mapping
# (self._global_to_local_rank_mapping) by calling _init_global_to_logical_rank_mapping
# the key is the global rank
@@ -437,7 +449,6 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
# in the same process group in the given axis
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
-
# if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
@@ -478,29 +489,37 @@ def flatten(self):
flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
- return DeviceMesh(self._physical_mesh_id,
- tuple(flatten_mesh_shape),
- mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
- mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
- init_process_group=self._init_process_group)
+ return DeviceMesh(
+ self._physical_mesh_id,
+ tuple(flatten_mesh_shape),
+ mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+ mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+ init_process_group=self._init_process_group,
+ )
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.1)
+ return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
- 0.01)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ + 0.01
+ )
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.001)
+ return (
+ self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
+ )
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
- (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ + 0.001
+ )
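A minimal construction sketch for the reformatted `DeviceMesh` above, assuming 4 physical devices arranged as a 2x2 logical mesh; `init_process_group=False` keeps the example free of any `torch.distributed` setup, and the alpha/beta values are made up for illustration.

```python
import torch

from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)
mesh = DeviceMesh(
    physical_mesh_id,
    (2, 2),  # 2x2 logical mesh over the 4 physical devices
    mesh_alpha=[1e-5, 1e-5],    # per-axis latency coefficients (illustrative)
    mesh_beta=[1e-10, 1e-10],   # per-axis per-byte coefficients (illustrative)
    init_process_group=False,
)
# Query the alpha-beta cost of a 1 MiB all-reduce along mesh axis 0.
print(mesh.all_reduce_cost(num_bytes=1 << 20, mesh_dim=0))
```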
diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py
index 0444a4816273..4d40d5badfd0 100644
--- a/colossalai/fx/_compatibility.py
+++ b/colossalai/fx/_compatibility.py
@@ -2,16 +2,14 @@
import torch
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
- from . import _meta_regist_12
META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
- from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True
@@ -36,7 +34,7 @@ def decorator(func):
else:
def wrapper(*args, **kwargs):
- raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
+ raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper
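The version gate above is the usual pattern for meta-tensor support; a self-contained sketch of the same idea follows (the decorator name `requires_meta` is illustrative, not this module's API).

```python
import functools

import torch

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])


def requires_meta(func):
    # Hypothetical decorator: pass the function through on torch >= 1.12,
    # otherwise replace it with a wrapper that raises, mirroring the module above.
    if (TORCH_MAJOR, TORCH_MINOR) >= (1, 12):
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")

    return wrapper
```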
diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py
index 52e8d63ae543..63f88682e85a 100644
--- a/colossalai/fx/_meta_regist_12.py
+++ b/colossalai/fx/_meta_regist_12.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from torch.utils._pytree import tree_map
@@ -16,13 +16,11 @@
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to calculate the length of one dimension of the output
@@ -125,7 +122,8 @@ def calc_conv_nd_return_shape(
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -159,22 +157,42 @@ def pick_memory_format():
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
-def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+def meta_conv_1(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
- return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
+def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+):
+ return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
-def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
+def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+):
print(input, weight, hx, cx)
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx)
- grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+ grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
return grad_input, grad_weight, grad_hx, grad_cx
@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
- save_invstd, train, eps, output_mask):
+def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
- reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
+ reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve
@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
-def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
+def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((bs, n_input, 1), device='meta')
- running_var = torch.empty((bs, n_input, 1), device='meta')
+ running_mean = torch.empty((bs, n_input, 1), device="meta")
+ running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
+def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
-def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
- return torch.empty((num_weights, grad_output.size(-1)),
- dtype=grad_output.dtype,
- device=grad_output.device,
- layout=grad_output.layout)
+def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+):
+ return torch.empty(
+ (num_weights, grad_output.size(-1)),
+ dtype=grad_output.dtype,
+ device=grad_output.device,
+ layout=grad_output.layout,
+ )
# ============================== Dropout ===========================================
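For reference, the `register_meta` pattern in this file is applied as in the sketch below. `aten.relu` is an op of our choosing for illustration (it may already ship a meta kernel upstream, which the `try/except` around `meta_lib.impl` tolerates), and `register_meta`/`aten` are the module-level names shown above.

```python
# A meta kernel only needs to produce outputs with the right shape and dtype
# on the 'meta' device; no real computation happens.
@register_meta(aten.relu.default)
def meta_relu(self):
    return torch.empty_like(self)
```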
diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 33b164800262..dfb5754d71c1 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Tuple
import torch
@@ -18,6 +18,7 @@
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import (
@@ -32,12 +33,13 @@
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
- __all__ = ['ActivationCheckpointCodeGen']
+ __all__ = ["ActivationCheckpointCodeGen"]
else:
- __all__ = ['python_code_with_activation_checkpoint']
+ __all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks():
@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The output will be a list
of tuples, each in the form (start_index, end_index).
"""
- ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
- elif current_region is not None and not 'activation_checkpoint' in node.meta:
+ elif current_region is not None and "activation_checkpoint" not in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
- act_offload_label = node.meta['activation_offload']
+ if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
+ act_offload_label = node.meta["activation_offload"]
if current_region == None:
current_region = act_offload_label
@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns:
bool
"""
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], list):
- return node.meta['activation_checkpoint'][check_idx] == None
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], list):
+ return node.meta["activation_checkpoint"][check_idx] == None
else:
return False
else:
@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], int):
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], int):
+ act_ckpt_label = node.meta["activation_checkpoint"]
else:
- act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
+ act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function uses an int as its label, use the old generation method
- if isinstance(node_list[0].meta['activation_checkpoint'], int):
- label = node_list[0].meta['activation_checkpoint']
+ if isinstance(node_list[0].meta["activation_checkpoint"], int):
+ label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there are more levels to fetch
- if level + 1 < len(node_list[0].meta['activation_checkpoint']):
+ if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
- delete_unused_value_func, level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
- ckpt_node_list = node_list[start:end + 1]
+ ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True
if idx in offload_starts:
@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
elif within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
- return_statement = f' {return_statement}\n\n'
+ return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement)
# we need to check if the checkpoint needs to offload the input
start_node_idx = start_idx[label]
- if 'activation_offload' in node_list[start_node_idx].meta:
- activation_offload = node_list[start_node_idx].meta['activation_offload']
+ if "activation_offload" in node_list[start_node_idx].meta:
+ activation_offload = node_list[start_node_idx].meta["activation_offload"]
else:
activation_offload = False
@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
- if 'activation_checkpoint' in user.meta:
- if user.meta['activation_checkpoint'] == label:
+ if "activation_checkpoint" in user.meta:
+ if user.meta["activation_checkpoint"] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
- usage += '\n'
+ usage += "\n"
body.append(usage)
within_ckpt_region = False
@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -629,7 +632,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -637,7 +640,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -662,16 +665,16 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -690,19 +693,18 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -728,90 +730,101 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use the nested activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -820,13 +833,13 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -837,11 +850,11 @@ def emit_node(node: Node, body):
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
@@ -861,7 +874,7 @@ def python_code_with_activation_checkpoint(self, root_module: str, namespace: _N
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -869,7 +882,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -894,12 +907,12 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
@@ -934,84 +947,94 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
if self._pytree_info is None:
- body.append(f'return {repr(node.args[0])}')
+ body.append(f"return {repr(node.args[0])}")
else:
- body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
+ body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use the nested activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1020,33 +1043,34 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
- has_orig_self = (orig_args[0] == 'self')
+ has_orig_self = orig_args[0] == "self"
if has_orig_self:
- free_vars.insert(0, 'self')
- if len(free_vars) > 0: # pytree has placeholders in it
+ free_vars.insert(0, "self")
+ if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
- f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+ f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
+ )
else:
orig_args = free_vars
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
- ckpt_func = ''.join(ckpt_func)
+ ckpt_func = "".join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
- if len(orig_args) == 0 or orig_args[0] != 'self':
- orig_args.insert(0, 'self')
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ if len(orig_args) == 0 or orig_args[0] != "self":
+ orig_args.insert(0, "self")
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
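To make the generated call site concrete, the snippet below reconstructs what `_gen_ckpt_usage` (shown earlier in this file) emits for a hypothetical region with label 0, one input `x`, and one output `relu`; the variable names are invented for illustration.

```python
# Reproduce the _gen_ckpt_usage string format for illustrative inputs.
label, activation_offload, use_reentrant = 0, False, True
input_vars, output_vars = ["x"], ["relu"]
outputs, inputs = ", ".join(output_vars), ", ".join(input_vars)
usage = (
    f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint("
    f"self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
)
print(usage)
# relu = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=True)
```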
diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index ebb9975f27db..8429a9607f7a 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -1,32 +1,35 @@
import os
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Type, Union
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
try:
- from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
- from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+ from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
+ from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
+
COLOGM = True
except:
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
+
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
-
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: Graph,
- class_name: str = 'GraphModule',
- ckpt_codegen: bool = True):
+ def __init__(
+ self,
+ root: Union[torch.nn.Module, Dict[str, Any]],
+ graph: Graph,
+ class_name: str = "GraphModule",
+ ckpt_codegen: bool = True,
+ ):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
@@ -60,7 +63,7 @@ def recompile(self) -> PythonCode:
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To separate the ckpt functions' code from the forward code
@@ -83,8 +86,8 @@ def recompile(self) -> PythonCode:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -108,7 +111,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -125,7 +128,13 @@ def __init__(self):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
- nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+ nn.Linear,
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
@@ -136,10 +145,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -156,19 +165,20 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
else:
class ColoGraphModule(GraphModule):
-
- def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
super().__init__(root, graph, class_name)
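
Before moving on, a usage sketch of the constructor that was just reformatted. The toy model is illustrative; the argument names and defaults come from the signature above, and running this requires `colossalai` to be installed:

```python
import torch
from torch.fx import symbolic_trace

from colossalai.fx.graph_module import ColoGraphModule

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
traced = symbolic_trace(model)

# ckpt_codegen=True swaps in ActivationCheckpointCodeGen before the parent
# GraphModule is built, so recompile() emits checkpoint-aware forward code.
gm = ColoGraphModule(model, traced.graph, class_name="GraphModule", ckpt_codegen=True)
out = gm(torch.randn(2, 8))
```
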
diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 245ba5d776da..99c8faaa0cc6 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -1,8 +1,6 @@
import numpy as np
import torch
import tqdm
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
- if 'block_split' in node.name:
+ if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
- block_node = gm.graph.create_node('call_function', block_split)
- setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
- setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+ block_node = gm.graph.create_node("call_function", block_split)
+ setattr(block_node, "fwd_flop", accumulate_fwd_flop)
+ setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
- if (node.op, node.target) == ('call_function', block_split):
+ if (node.op, node.target) == ("call_function", block_split):
gm.graph.erase_node(node)
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
- for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
- for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+ for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
+ for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
- for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
- for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
- for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+ for s in tqdm.tqdm(
+ range(1, num_stages + 1), desc="stage", position=2, leave=False
+ ): # pylint: disable=too-many-nested-blocks
+ for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
+ for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
- if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
+ if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf
best_solution = None
last_max_compute_cost = 0.0
- gap = 1e6 # temporary magic number, unit: flops
+ gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space.
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
- cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
- max_compute_cost)
+ cost, solution = do_dp_split_gpipe_impl(
+ len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
+ )
if cost < best_cost:
best_cost = cost
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': a single fx_node
# 'block': a block composed of many fx_nodes
-def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
- assert mode in ['node', 'block']
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
+ assert mode in ["node", "block"]
# nodes or blocks will be used in partition.
node_list = []
- if mode == 'node':
+ if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
- elif mode == 'block':
+ elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
- for (_, next_start_node) in best_solution:
+ for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
- split_node = gm.graph.create_node('call_function', pipe_split)
+ split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
- if mode == 'block':
+ if mode == "block":
remove_blocks(gm)
gm.recompile()
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
# If the next node is the output node, we will insert the split annotation before
# it to make sure there is at least one node in the last partition.
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if node.op == 'placeholder':
+ if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need to run the meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -355,7 +356,7 @@ def split_callback(n: torch.fx.Node):
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
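
The triple loop in `do_dp_split_gpipe_impl` above is an interval DP: `f[s, i]` is the cheapest way to pack nodes `i..n-1` into `s` stages, with every stage capped at `max_compute_cost`. A self-contained toy re-derivation (the costs are illustrative, and a prefix sum stands in for the precomputed `compute_costs` table):

```python
import numpy as np

flops = np.array([4.0, 1.0, 3.0, 2.0])  # toy per-node fwd+bwd cost
n = len(flops)
prefix = np.concatenate([[0.0], np.cumsum(flops)])


def stage_cost(i: int, k: int) -> float:
    # cost of nodes i..k-1, playing the role of compute_costs[i, k - 1]
    return prefix[k] - prefix[i]


def dp(num_stages: int, max_compute_cost: float) -> float:
    f = np.full((num_stages + 1, n + 1), np.inf)
    f[0, n] = 0.0
    for s in range(1, num_stages + 1):
        for i in range(n - 1, -1, -1):
            for k in range(n, i, -1):
                cost = stage_cost(i, k)
                if cost <= max_compute_cost and f[s - 1, k] + cost < f[s, i]:
                    f[s, i] = f[s - 1, k] + cost
    return f[num_stages, 0]


print(dp(num_stages=2, max_compute_cost=6.0))  # 10.0, e.g. [4, 1] | [3, 2]
```
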
diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py
index 81ac64205528..5440a4eadbbf 100644
--- a/colossalai/fx/passes/concrete_info_prop.py
+++ b/colossalai/fx/passes/concrete_info_prop.py
@@ -1,5 +1,5 @@
from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
@@ -85,10 +85,10 @@ def run_node(self, n: Node) -> Any:
self._is_proped = True
result, meta_info = super().run_node(n)
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -98,7 +98,7 @@ def run_node(self, n: Node) -> Any:
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -119,7 +119,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -138,7 +138,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -157,7 +157,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -175,7 +175,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -197,7 +197,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -228,7 +228,7 @@ def propagate(self, *args):
"""
return self.run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -238,9 +238,11 @@ def summary(self, unit: str = 'MB') -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -249,10 +251,10 @@ def summary(self, unit: str = 'MB') -> str:
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -261,30 +263,32 @@ def time_repr(time: float):
for node in self.module.graph.nodes:
node: Node
- node_summaries.append([
- node.op,
- str(node),
- time_repr(node.meta['fwd_time']),
- time_repr(node.meta['bwd_time']),
- node.meta['save_fwd_in'],
- mem_repr(node.meta['fwd_mem_out']),
- mem_repr(node.meta['fwd_mem_tmp']),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ time_repr(node.meta["fwd_time"]),
+ time_repr(node.meta["bwd_time"]),
+ node.meta["save_fwd_in"],
+ mem_repr(node.meta["fwd_mem_out"]),
+ mem_repr(node.meta["fwd_mem_tmp"]),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward time',
- 'Backward time',
- 'SAVE_FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward time",
+ "Backward time",
+ "SAVE_FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
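
The unit handling in `mem_repr` above is easy to sanity-check in isolation. Here it is lifted into a standalone function with the unit made an explicit parameter (in the original it is a closure over `summary`'s `unit` argument; the body is otherwise the one from the hunk above):

```python
def mem_repr(mem: int, unit: str = "MB") -> str:
    # Divisors are powers of 1024; keys are matched case-insensitively.
    unit_divisor_map = {
        "kb": 1024,
        "mb": 1024**2,
        "gb": 1024**3,
        "tb": 1024**4,
    }
    return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"


print(mem_repr(3 * 1024**2))      # 3.00 MB
print(mem_repr(1536, unit="KB"))  # 1.50 KB
```
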
diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
index 4571bd93a790..3d032a27db63 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
@@ -1,14 +1,11 @@
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
-from copy import deepcopy
+from typing import List
+
+import torch
+
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs):
@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
- setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
+ setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
# add above dicts into graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph
for node in nodes:
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
with mod_graph.inserting_after(node):
- origin_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(origin_dict_node, node_to_index_dict[node]))
+ origin_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_after(origin_spec_node):
- set_sharding_spec_node = mod_graph.create_node('call_function',
- builtins.setattr,
- args=(node, 'sharding_spec', origin_spec_node))
+ set_sharding_spec_node = mod_graph.create_node(
+ "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
+ )
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
- input_specs_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_dict_node, node_to_index_dict[node]))
+ input_specs_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_before(user_node):
- sharding_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_specs_node, node_index))
+ sharding_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_specs_node, node_index)
+ )
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
+ shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm
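
The insertion pattern both passes above rely on — create a `placeholder` for a spec dict, then index it with `operator.getitem` in front of the node that needs it — can be reproduced on a toy graph. A hedged sketch (the module and the dict index are illustrative; only the fx calls mirror the pass):

```python
import operator

import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
graph = gm.graph

first_compute_node = next(n for n in graph.nodes if n.op != "placeholder")
with graph.inserting_before(first_compute_node):
    # Thread the spec dict in as a placeholder, then pick this node's entry,
    # mirroring solution_annotation_pass / shape_consistency_pass above.
    spec_dict = graph.create_node("placeholder", target="sharding_spec_convert_dict")
    spec_entry = graph.create_node("call_function", operator.getitem, args=(spec_dict, 0))

print(graph)  # both new nodes now precede the linear call
```
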
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index ab203dfd7440..1720aa58da2b 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -109,13 +109,13 @@ def extract_tensor_meta(obj):
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
- n.meta['tensor_meta'] = tensor_meta
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta["tensor_meta"] = tensor_meta
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
- setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
- setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
+ setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
+ setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -125,7 +125,7 @@ def extract_tensor_meta(obj):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -146,7 +146,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -165,7 +165,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -184,7 +184,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -202,7 +202,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -224,7 +224,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -240,7 +240,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
- if hasattr(args[0], '_tensor'):
+ if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
@@ -257,7 +257,7 @@ def propagate(self, *args):
"""
return super().run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -267,9 +267,11 @@ def summary(self, unit: str = 'MB') -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -278,10 +280,10 @@ def summary(self, unit: str = 'MB') -> str:
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -292,35 +294,37 @@ def flops_repr(flop: int) -> str:
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
- node_summaries.append([
- node.op,
- str(node),
- flops_repr(node.meta['fwd_flop']),
- flops_repr(node.meta['bwd_flop']),
- mem_repr(accumulate_size),
- mem_repr(calculate_fwd_in(node)),
- mem_repr(calculate_fwd_out(node)),
- mem_repr(calculate_fwd_tmp(node)),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ flops_repr(node.meta["fwd_flop"]),
+ flops_repr(node.meta["bwd_flop"]),
+ mem_repr(accumulate_size),
+ mem_repr(calculate_fwd_in(node)),
+ mem_repr(calculate_fwd_out(node)),
+ mem_repr(calculate_fwd_tmp(node)),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward FLOPs',
- 'Backward FLOPs',
- 'Accumulated Memory',
- 'FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward FLOPs",
+ "Backward FLOPs",
+ "Accumulated Memory",
+ "FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
+
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
- gm.to('cpu')
+ gm.to("cpu")
del interp
return gm
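
A usage sketch for this profiling pass: `metainfo_trace` and its signature are taken from the hunk above, while the model and input are illustrative. It moves the module to GPU when one is available, annotates every fx node, and returns the module to CPU:

```python
import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import metainfo_trace

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
gm = symbolic_trace(model)

# Annotates every fx node with fwd/bwd FLOPs and memory estimates and, with
# verbose=True, prints the tabulate summary described above.
gm = metainfo_trace(gm, torch.randn(4, 16), verbose=True, unit="MB")
```
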
diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py
index efdd34a01fe0..73379f73689c 100644
--- a/colossalai/fx/passes/passes_for_gpt2_test.py
+++ b/colossalai/fx/passes/passes_for_gpt2_test.py
@@ -5,7 +5,6 @@
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
-from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata
@@ -13,9 +12,9 @@
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
- '''
+ """
This pass is only used for the GPT-2 performance test; it may move into adding_split_node_pass.py and will be deprecated in the future.
- '''
+ """
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
- '''
+ """
This pass is used in the GPT-2 test; only part of its changes may be merged into
split_with_split_nodes_pass, and it will be deprecated in the future.
- '''
+ """
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
- '''
+ """
This method is used to eliminate the outputs of the previous partition which are unused in the next partition.
The split module pass treats partitions as a DAG, but pipeline parallelism needs to treat them as a singly linked list.
The difference is that if an output from partition 0 is an input argument of partition 3, the DAG will not transfer it
through partition 1 and partition 2. However, in a singly linked list, we need to do so.
- '''
+ """
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
@@ -114,7 +113,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders):
continue
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
@@ -125,7 +124,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders):
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -134,7 +133,7 @@ def split_callback(n: torch.fx.Node):
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
@@ -200,13 +199,12 @@ def _gen_all_ancestors_set(node):
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
- if n.op != 'placeholder' and n._fx_partition > partition_name:
+ if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
@@ -237,7 +235,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
@@ -252,12 +250,12 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
- origin_partition_name = getattr(node, '_fx_partition', None)
+ origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -287,7 +285,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -295,26 +293,24 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs,
- name=node.name)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -323,14 +319,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.name, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -344,13 +340,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -358,14 +355,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
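
The DAG-versus-linked-list point in the `refill_outputs_and_placeholders` docstring is worth a concrete miniature. In a linear pipeline, a value produced by partition 0 and consumed only by partition 3 must still be relayed by partitions 1 and 2 (a pure-Python toy with made-up def/use tables, no fx involved):

```python
defs = {"t": 0}    # value -> partition that produces it
uses = {"t": [3]}  # value -> partitions that consume it


def relay_partitions(value: str) -> list:
    # Partitions strictly between producer and last consumer only forward
    # the value; a DAG would skip them entirely.
    first, last = defs[value], max(uses[value])
    return list(range(first + 1, last))


print(relay_partitions("t"))  # [1, 2]
```
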
diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index d2bad06bb45a..be8261f2a3f4 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,19 +1,32 @@
+import operator
+
import torch
import torch.nn as nn
-import operator
-from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import ShardSpec
-from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
+
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec
+from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
- torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
- operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+ torch.add,
+ operator.add,
+ torch.abs,
+ torch.cos,
+ torch.exp,
+ torch.mul,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ operator.neg,
+ torch.multiply,
+ torch.nn.functional.relu,
+ torch.nn.functional.dropout,
]
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
- """weight_split
+ """weight_split
split an nn.Parameter
Args:
@@ -60,9 +73,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
-    This IR pass checks for a transformer-MLP-like structure and annotates column and row sharding on the linear layers.
+    This IR pass checks for a transformer-MLP-like structure and annotates column and row sharding on the linear layers.
"""
- #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+ # TODO: Needs to handle special cases, like x = linear(x) + linear(x)
graph = graph_module.graph
world_size = process_group.world_size()
@@ -70,7 +83,7 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# traverse the graph to look for consecutive linear layers
is_linear_module = False
- if node.op == 'call_module':
+ if node.op == "call_module":
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
@@ -80,31 +93,31 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
- annotation_record['row'] = module
+ annotation_record["row"] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
- if shard_type == 'row':
+ if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', dist_spec)
- setattr(module.weight, 'comp_spec', comp_spec)
- elif shard_type == 'col':
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", dist_spec)
+ setattr(module.weight, "comp_spec", comp_spec)
+ elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', weight_dist_spec)
- setattr(module.weight, 'comp_spec', weight_comp_spec)
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", weight_dist_spec)
+ setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
- setattr(module.bias, 'pg', process_group)
- setattr(module.bias, 'dist_spec', bias_dist_spec)
- setattr(module.bias, 'comp_spec', bias_comp_spec)
+ setattr(module.bias, "pg", process_group)
+ setattr(module.bias, "dist_spec", bias_dist_spec)
+ setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
@@ -112,16 +125,16 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
- annotation_record['col'] = module
+ annotation_record["col"] = module
if start_tracking and not is_linear_module:
# check against the white list
# if non-element wise op is found, we reset the tracking
- if node.op == 'call_module':
+ if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
- elif node.op == 'call_function' or node.op == 'call_method':
+ elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
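
For intuition, the pattern `transformer_mlp_pass` hunts for can be detected in a few lines of fx: two `nn.Linear` layers separated only by elementwise ops, the first tagged column-sharded and the second row-sharded (Megatron-style 1D tensor parallelism). A detection-only sketch, with a plain dict standing in for the dist/compute specs and an illustrative MLP module:

```python
import torch
from torch.fx import symbolic_trace

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.up = torch.nn.Linear(8, 32)
        self.act = torch.nn.ReLU()
        self.down = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.down(self.act(self.up(x)))


gm = symbolic_trace(MLP())
plan, tracking = {}, False
for node in gm.graph.nodes:
    if node.op != "call_module":
        continue
    mod = gm.get_submodule(node.target)
    if isinstance(mod, torch.nn.Linear):
        plan[node.target] = "row" if tracking else "col"
        tracking = not tracking
    elif tracking and type(mod) not in ELEMENTWISE_MODULE_OP:
        tracking = False  # a non-elementwise op breaks the MLP pattern

print(plan)  # {'up': 'col', 'down': 'row'}
```
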
diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
index 61ed037ab7a1..67a2432595d6 100644
--- a/colossalai/fx/passes/split_module.py
+++ b/colossalai/fx/passes/split_module.py
@@ -25,12 +25,14 @@ def __init__(self, name: str):
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
- return f"name: {self.name},\n" \
- f" nodes: {self.node_names},\n" \
- f" inputs: {self.inputs},\n" \
- f" outputs: {self.outputs},\n" \
- f" partitions dependent on: {self.partitions_dependent_on},\n" \
+ return (
+ f"name: {self.name},\n"
+ f" nodes: {self.node_names},\n"
+ f" inputs: {self.inputs},\n"
+ f" outputs: {self.outputs},\n"
+ f" partitions dependent on: {self.partitions_dependent_on},\n"
f" partition dependents: {self.partition_dependents}"
+ )
# Creates subgraphs out of main graph
@@ -117,10 +119,9 @@ def forward(self, x, y):
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
@@ -134,7 +135,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
- def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
@@ -161,7 +162,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
@@ -178,7 +179,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -208,7 +209,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -216,25 +217,24 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -243,14 +243,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.target, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -264,13 +264,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -278,15 +279,15 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
index bb4f3cd6a490..c51f49a30e8a 100644
--- a/colossalai/fx/passes/utils.py
+++ b/colossalai/fx/passes/utils.py
@@ -1,7 +1,9 @@
-import torch
from typing import Dict
-from torch.fx.node import Node, map_arg
+
+import torch
from torch.fx.graph import Graph
+from torch.fx.node import Node, map_arg
+
def get_comm_size(prev_partition, next_partition):
"""
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
- comm_size += n.meta['tensor_meta'].numel
+ comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
- if node.op == 'placeholder':
+ if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
- if node.op == 'placeholder':
+ if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
-
+
Returns:
        List of ``Node`` objects in whose ``args`` or ``kwargs`` the given node appears.
"""
@@ -120,7 +122,7 @@ def forward(self, x):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
-
+
Output:
graph():
%x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def forward(self, x):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
- if node.op == 'output':
+ if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
- assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
- assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
+ assert (
+ node.graph.owning_module is not None
+ ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
+ assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module
-
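The `bfs_level` pass above assigns each node its breadth-first distance from the graph's placeholders, which later partitioning passes use to group nodes. A minimal sketch of the same traversal on a plain `torch.fx` graph (the `Toy` module and the shortest-distance variant are mine, for illustration only):

```python
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(4, 4)
        self.lin2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin2(self.lin1(x)) + x


gm = symbolic_trace(Toy())

# BFS from the placeholders: a node's level is its distance from the inputs.
level, frontier = 0, [n for n in gm.graph.nodes if n.op == "placeholder"]
seen = set(frontier)
while frontier:
    nxt = []
    for node in frontier:
        node.bfs_level = level
        for user in node.users:  # consumers of this node's output
            if user.op != "output" and user not in seen:
                seen.add(user)
                nxt.append(user)
    frontier, level = nxt, level + 1

for node in gm.graph.nodes:
    if hasattr(node, "bfs_level"):
        print(node.name, node.bfs_level)
```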
diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py
index 8bcbde0eb23b..89dd2b3df617 100644
--- a/colossalai/fx/profiler/__init__.py
+++ b/colossalai/fx/profiler/__init__.py
@@ -12,7 +12,16 @@
)
from .tensor import MetaTensor
else:
- from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+ from .experimental import (
+ meta_profiler_function,
+ meta_profiler_module,
+ profile_function,
+ profile_method,
+ profile_module,
+ calculate_fwd_in,
+ calculate_fwd_tmp,
+ calculate_fwd_out,
+ )
from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size
diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py
index 5763a46dc83f..fad9bb272bff 100644
--- a/colossalai/fx/profiler/constants.py
+++ b/colossalai/fx/profiler/constants.py
@@ -1,6 +1,6 @@
import torch
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
+__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"]
aten = torch.ops.aten
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index a5e8880322b8..05f9b50ce575 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass, field
from enum import Enum
-from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
@@ -69,8 +68,8 @@ class GraphInfo:
def is_phase(n: Node, phase: Phase) -> bool:
- assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
- return n.meta['phase'] == phase
+ assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
+ return n.meta["phase"] == phase
@compatibility(is_backward_compatible=False)
@@ -103,9 +102,9 @@ def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
- peak_mem += activation_size(k.meta['saved_tensor'])
- if v <= float('-inf') and is_phase(k, Phase.FORWARD):
- peak_mem -= activation_size(k.meta['saved_tensor'])
+ peak_mem += activation_size(k.meta["saved_tensor"])
+ if v <= float("-inf") and is_phase(k, Phase.FORWARD):
+ peak_mem -= activation_size(k.meta["saved_tensor"])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
@@ -123,19 +122,19 @@ def _peak_memory(deps: Dict[Node, int]):
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
- graph_info.fwd_in += n.meta['saved_tensor']
+ graph_info.fwd_in += n.meta["saved_tensor"]
if is_phase(n, Phase.FORWARD):
- graph_info.fwd_tmp += n.meta['saved_tensor']
+ graph_info.fwd_tmp += n.meta["saved_tensor"]
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
                # basically, a backward node without users is a `grad_out` node
- graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
+ graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
- deps[input_n] = float('-inf')
+ deps[input_n] = float("-inf")
return graph_info
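The `deps` bookkeeping above is a reference count per producer: each consumed input is decremented, and once it reaches zero it is pinned to `-inf`, so a freed activation can never be counted (or freed) twice. A toy, `torch.fx`-free sketch of that counting scheme with made-up node sizes:

```python
import math

# A made-up three-node chain: each node produces one activation and
# consumes the outputs listed in `inputs`.
size = {"x": 1024, "y": 2048, "z": 512}  # bytes per activation
inputs = {"x": [], "y": ["x"], "z": ["y"]}
deps = {"x": 1, "y": 1, "z": 0}          # outstanding consumers per node

cur = peak = 0
for node in ["x", "y", "z"]:             # topological order
    cur += size[node]                    # allocate this node's output
    peak = max(peak, cur)
    for inp in inputs[node]:
        deps[inp] -= 1
        if deps[inp] <= 0:
            deps[inp] = -math.inf        # freed; -inf prevents double-frees
            cur -= size[inp]

print("peak activation memory:", peak, "bytes")  # 3072
```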
diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py
index 57ff3fd91299..02758e7643af 100644
--- a/colossalai/fx/profiler/experimental/constants.py
+++ b/colossalai/fx/profiler/experimental/constants.py
@@ -2,7 +2,7 @@
import torch
-__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
+__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"]
# TODO fill out the inplace ops
INPLACE_OPS = [
@@ -20,25 +20,25 @@
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
- 'transpose',
- 'permute',
+ "transpose",
+ "permute",
# TODO: reshape may return a copy of the data if the data is not contiguous
- 'reshape',
- 'dim',
- 'flatten',
- 'size',
- 'view',
- 'unsqueeze',
- 'to',
- 'type',
- 'flatten',
+ "reshape",
+ "dim",
+ "flatten",
+ "size",
+ "view",
+ "unsqueeze",
+ "to",
+ "type",
+ "flatten",
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
- 'chunk',
- 'contiguous',
- 'expand',
- 'mean',
- 'split',
+ "chunk",
+ "contiguous",
+ "expand",
+ "mean",
+ "split",
]
diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index 5c545260e72b..d890fdb66fc2 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -9,7 +9,7 @@
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# this is kept for compatibility
@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
+
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0
-CALL_FUNCTION_MSG = \
-"""
+CALL_FUNCTION_MSG = """
Colossal-AI does not yet support profiling for {}; you can manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ...
return flops, macs
"""
-CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
-CALL_MODULE_MSG = \
-"""
+CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add the target to INPLACE_METHOD={}; otherwise, add it to NON_INPLACE_METHOD={}"
+CALL_MODULE_MSG = """
Colossal-AI does not yet support profiling for {}; you can manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: "Target") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
- target.__name__), CALL_FUNCTION_MSG.format(target)
+ target.__name__
+ ), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if target not in INPLACE_OPS and not kwargs.get('inplace', False):
+ if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
@@ -112,7 +112,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: "Target") -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
@@ -126,11 +126,12 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
self_obj, *args_tail = args
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
+ assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
- target, INPLACE_METHOD, NON_INPLACE_METHOD)
+ target, INPLACE_METHOD, NON_INPLACE_METHOD
+ )
    # call_method nodes have no parameters, are MOSTLY(?) inplace, and incur no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
@@ -161,7 +162,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
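The `CALL_FUNCTION_MSG` template above is the intended extension point: an unknown callable is profiled by registering an estimator that returns `(flops, macs)` for it. A sketch following that template, where `my_op` is a hypothetical user function and the per-element costs are illustrative:

```python
from typing import Tuple

import torch

from colossalai.fx.profiler.experimental import meta_profiler_function


def my_op(x: torch.Tensor) -> torch.Tensor:  # hypothetical function to profile
    return x * 2 + 1


@meta_profiler_function.register(my_op)
def profile_my_op(input: torch.Tensor, *args) -> Tuple[int, int]:
    flops = input.numel() * 2  # one multiply and one add per element
    macs = 0                   # no fused multiply-accumulate pairs
    return flops, macs
```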
diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
index a43aef063e19..c518ec28da41 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
# TODO: different activations have different FLOP counts; currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
index 8d1c8a8c6877..f1b9bb97c6c6 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other):
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
-@meta_profiler_function.register('add') # for built-in op +
-@meta_profiler_function.register('iadd') # for built-in op +=
-@meta_profiler_function.register('eq')    # for built-in op ==
-@meta_profiler_function.register('sub') # for built-in op -
-@meta_profiler_function.register('isub') # for built-in op -=
-@meta_profiler_function.register('mul') # for built-in op *
-@meta_profiler_function.register('imul') # for built-in op *=
-@meta_profiler_function.register('floordiv') # for built-in op //
-@meta_profiler_function.register('ifloordiv') # for built-in op //=
+@meta_profiler_function.register("add") # for built-in op +
+@meta_profiler_function.register("iadd") # for built-in op +=
+@meta_profiler_function.register("eq") # for built-in op =
+@meta_profiler_function.register("sub") # for built-in op -
+@meta_profiler_function.register("isub") # for built-in op -=
+@meta_profiler_function.register("mul") # for built-in op *
+@meta_profiler_function.register("imul") # for built-in op *=
+@meta_profiler_function.register("floordiv") # for built-in op //
+@meta_profiler_function.register("ifloordiv") # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register(torch.matmul)
-@meta_profiler_function.register('matmul') # for built-in op @
+@meta_profiler_function.register("matmul") # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T
@meta_profiler_function.register(torch.var_mean)
-def torch_var_mean(input: torch.Tensor,
- dim: Union[int, Tuple[int, ...]],
- unbiased: Optional[bool] = True,
- keepdim: Optional[bool] = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
- assert out is None, 'saving to out is not supported yet'
+def torch_var_mean(
+ input: torch.Tensor,
+ dim: Union[int, Tuple[int, ...]],
+ unbiased: Optional[bool] = True,
+ keepdim: Optional[bool] = False,
+ *,
+ out: Optional[torch.Tensor] = None,
+) -> Tuple[int, int]:
+ assert out is None, "saving to out is not supported yet"
flops = input.numel() * 3
macs = 0
return flops, macs
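`torch_matmul` above counts one MAC per scalar-product term, `macs = prod(input.shape) * other.shape[-1]`, and the estimators in this file treat one MAC as two FLOPs (a multiply plus an add). A quick worked check of that arithmetic:

```python
import operator
from functools import reduce

import torch

a = torch.empty(8, 16)
b = torch.empty(16, 32)

macs = reduce(operator.mul, a.shape) * b.shape[-1]  # 8 * 16 * 32 = 4096
flops = 2 * macs                                    # multiply + add per MAC

assert (macs, flops) == (4096, 8192)
```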
diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
index d6e43d781b8b..1d362015fc8b 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
@@ -1,5 +1,7 @@
-import torch
from typing import Optional
+
+import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py
index 01fe4c871370..ecc578d61b91 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
index c4ea508d70f8..2ad029eda039 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
@@ -1,5 +1,7 @@
from typing import List, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm(
@meta_profiler_function.register(torch.nn.functional.group_norm)
-def torch_nn_func_groupnorm(input: torch.Tensor,
- num_groups: int,
- weight: Optional[torch.Tensor] = None,
- bias: Optional[torch.Tensor] = None,
- eps: float = 1e-5) -> Tuple[int, int]:
+def torch_nn_func_groupnorm(
+ input: torch.Tensor,
+ num_groups: int,
+ weight: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ eps: float = 1e-5,
+) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
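The group-norm rule above charges four elementwise operations per element (for the mean, variance, and normalization) plus one more when an affine transform is applied. A worked check of that heuristic:

```python
import torch

x = torch.empty(2, 8, 4, 4)  # (N, C, H, W)
has_affine = True

flops = x.numel() * (5 if has_affine else 4)
assert flops == 2 * 8 * 4 * 4 * 5 == 1280
```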
diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
index a639f5ee83c1..c91deab906d4 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
@@ -1,5 +1,7 @@
-from typing import Tuple, Union
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
index 1e8561206ba0..58c9889ad98e 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
@@ -1,6 +1,6 @@
import operator
from typing import Any, Tuple
-import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
index abdd7ad565ba..67e90fb69acd 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
@@ -1,7 +1,9 @@
-from functools import reduce
import operator
+from functools import reduce
from typing import Any, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
@meta_profiler_function.register(torch.max)
-def torch_max(input: torch.Tensor,
- dim: int = None,
- keepdim: bool = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+def torch_max(
+ input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None
+) -> Tuple[int, int]:
macs = 0
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
index 2ebf514ad269..ae065e0c7c17 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: different activations have different FLOP counts; currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py
index 8daf74b232bf..dfaee75e0432 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/attention.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py
@@ -1,19 +1,23 @@
from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: the memory cost here is hard to compute
@meta_profiler_module.register(torch.nn.MultiheadAttention)
-def torch_nn_msa(self: torch.nn.MultiheadAttention,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- key_padding_mask: Optional[torch.Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[torch.Tensor] = None,
- average_attn_weights: bool = True) -> Tuple[int, int]:
- if getattr(self, 'batch_first', False):
+def torch_nn_msa(
+ self: torch.nn.MultiheadAttention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_padding_mask: Optional[torch.Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+ average_attn_weights: bool = True,
+) -> Tuple[int, int]:
+ if getattr(self, "batch_first", False):
batch_size = query.shape[0]
len_idx = 1
else:
@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
flops += qlen * qdim
# Initial projections
- flops += 2 * ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+    flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim))  # QW, KW, VW projections
- macs += ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+    macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)  # QW, KW, VW projections
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
v_head_dim = vdim // num_heads
head_flops = (
- 2 * (qlen * klen * qk_head_dim) # QK^T
- + (qlen * klen) # softmax
- + 2 * (qlen * klen * v_head_dim) # AV
+        2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim)  # QK^T, softmax, AV
)
- head_macs = ((qlen * klen * qk_head_dim) # QK^T
- + 2 * (qlen * klen * v_head_dim) # AV
- )
+    head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim)  # QK^T, AV
flops += num_heads * head_flops
    macs += num_heads * head_macs
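Per head, the estimator above charges `2 * qlen * klen * qk_head_dim` FLOPs for `QK^T`, `qlen * klen` for the softmax, and `2 * qlen * klen * v_head_dim` for `AV`, while MACs cover only the matrix products. A worked check with made-up sizes:

```python
qlen = klen = 128
qk_head_dim = v_head_dim = 64

head_flops = 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim)
head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim)

assert head_flops == 2_097_152 + 16_384 + 2_097_152 == 4_210_688
assert head_macs == 1_048_576 + 2_097_152 == 3_145_728
```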
diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
index a4c15b91e611..90e494c77f5b 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(
operator.mul, input.shape
- ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
+ ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
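Every output-size expression in this file is an instance of the standard convolution formula, `out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)`, with the corresponding inverse for transposed convolutions. A quick numerical check against PyTorch itself:

```python
import math

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
h_in = w_in = 32

h_out = math.floor((h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1)
w_out = math.floor((w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1)

out = conv(torch.randn(1, 3, h_in, w_in))
assert out.shape[-2:] == (h_out, w_out) == (16, 16)
```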
diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
index 417e0ed46863..7361239eb1bd 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py
index e1ffb6f244d2..71fed3196c13 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
index 49e5e6fa5384..5a64e44947b7 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
@@ -16,8 +16,12 @@
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
-def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
- torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
+def torch_nn_normalize(
+ self: Union[
+ torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
+ ],
+ input: torch.Tensor,
+) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch
try:
import apex
+
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
index e429ac3eea28..b3b630b2dee9 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
index 6e733d6da915..8a4c828dbd27 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
@@ -1,12 +1,15 @@
-from functools import reduce
import operator
+from functools import reduce
+from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
-def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
- w_hh: torch.Tensor) -> Tuple[int, int]:
+def _rnn_flops(
+ flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
+) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
flops = 0
macs = 0
for i in range(self.num_layers):
- w_ih = self.__getattr__('weight_ih_l' + str(i))
- w_hh = self.__getattr__('weight_hh_l' + str(i))
+ w_ih = self.__getattr__("weight_ih_l" + str(i))
+ w_hh = self.__getattr__("weight_hh_l" + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l' + str(i))
- b_hh = self.__getattr__('bias_hh_l' + str(i))
+ b_ih = self.__getattr__("bias_ih_l" + str(i))
+ b_hh = self.__getattr__("bias_hh_l" + str(i))
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
- w_ih = self.__getattr__('weight_ih_l')
- w_hh = self.__getattr__('weight_hh_l')
+ w_ih = self.__getattr__("weight_ih_l")
+ w_hh = self.__getattr__("weight_hh_l")
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l')
- b_hh = self.__getattr__('bias_hh_l')
+ b_ih = self.__getattr__("bias_ih_l")
+ b_hh = self.__getattr__("bias_hh_l")
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= input.shape[0]
macs *= input.shape[0]
diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
index d3aed874eb10..06be25246a71 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
@@ -1,7 +1,8 @@
-import operator
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py
index 7d73bce321e4..d47129cd2978 100644
--- a/colossalai/fx/profiler/experimental/registry.py
+++ b/colossalai/fx/profiler/experimental/registry.py
@@ -1,11 +1,9 @@
class ProfilerRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,5 +19,5 @@ def has(self, source):
return source in self.store
-meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
-meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
+meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile")
+meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile")
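`ProfilerRegistry` is a plain decorator-driven lookup table: `register(source)` stores an estimator under a key (a callable, a module class, or a method name), and `has`/`get` retrieve it during profiling. A minimal usage sketch with a hypothetical key:

```python
from colossalai.fx.profiler.experimental import meta_profiler_function


@meta_profiler_function.register("neg")  # hypothetical key for the built-in `-x`
def profile_neg(input, *args):
    return input.numel(), 0  # (flops, macs)


assert meta_profiler_function.has("neg")
estimator = meta_profiler_function.get("neg")
```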
diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py
index 1e53ed0bf8ec..90e8c3b7cfe4 100644
--- a/colossalai/fx/profiler/experimental/shard_utils.py
+++ b/colossalai/fx/profiler/experimental/shard_utils.py
@@ -1,8 +1,6 @@
# for PyTorch 1.11 compatibility uses
-from typing import Dict, List, Tuple, Union
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from ..._compatibility import compatibility
@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool:
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
- return n.meta['save_fwd_in']
+ return n.meta["save_fwd_in"]
@compatibility(is_backward_compatible=True)
@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int:
Returns:
fwd_out (int): the result of `fwd_out`
"""
- return n.meta['fwd_mem_out']
+ return n.meta["fwd_mem_out"]
diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py
index 6ccbcb01cdc1..e8eb5f25cb6c 100644
--- a/colossalai/fx/profiler/memory_utils.py
+++ b/colossalai/fx/profiler/memory_utils.py
@@ -1,11 +1,11 @@
from typing import Dict, List, Tuple, Union
import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = ["activation_size", "parameter_size", "is_inplace"]
@compatibility(is_backward_compatible=True)
@@ -63,6 +63,7 @@ def is_inplace(n: Node):
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
+
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index ba090a2ec51b..8fae0f2ecb45 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -173,8 +173,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -218,15 +221,16 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+ Count flops for zero flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
- torch.__version__) < version.parse('2.0.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse(
+ "2.0.0"
+):
flop_mapping = {
- # gemm, gemv and dot
+ # gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit,
@@ -234,13 +238,11 @@ def zero_flop_jit(*args):
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
@@ -249,8 +251,7 @@ def zero_flop_jit(*args):
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
@@ -275,7 +276,7 @@ def zero_flop_jit(*args):
}
elementwise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -296,8 +297,7 @@ def zero_flop_jit(*args):
aten.exp.default,
aten.sin.default,
aten.cos.default,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -320,8 +320,7 @@ def zero_flop_jit(*args):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
@@ -362,7 +361,7 @@ def zero_flop_jit(*args):
aten.zero_.default,
aten.zeros_like.default,
aten.fill_.Scalar,
- aten.stack.default
+ aten.stack.default,
] # yapf: disable
for op in zero_flop_aten:
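`flop_mapping` keys `aten` overloads to counter callbacks, with `elementwise_flop_counter`-style helpers charging a fixed cost per element and everything in `zero_flop_aten` costing nothing. A toy version of the dispatch pattern (the mapping and counters below are illustrative, not the module's own):

```python
import operator
from functools import reduce
from typing import Any, Callable, Dict, List

import torch


def elementwise_flops(inputs: List[Any], outputs: List[Any]) -> int:
    # one FLOP per output element, as for add/mul/relu-like kernels
    return reduce(operator.mul, outputs[0].shape, 1)


def zero_flops(*args) -> int:
    return 0


toy_mapping: Dict[Callable, Callable] = {
    torch.ops.aten.add.Tensor: elementwise_flops,
    torch.ops.aten.zeros_like.default: zero_flops,
}

a = torch.empty(8, 8, device="meta")
print(toy_mapping[torch.ops.aten.add.Tensor]([a, a], [a]))  # 64
```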
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index c87cd4321d31..97e70db6290e 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -15,7 +15,7 @@
from .opcount import flop_mapping
from .tensor import MetaTensor
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# backward is executed.
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
-
_node: Node = None
def __repr__(self):
@@ -186,24 +185,24 @@ def __repr__(self):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
- node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+ node = subgraph.create_node("call_function", func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
- node.meta['phase'] = phase
+ node.meta["phase"] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
- node.meta['phase'] = Phase.PLACEHOLDER
+ node.meta["phase"] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation
- node.meta['saved_tensor'] = []
+ node.meta["saved_tensor"] = []
if phase == Phase.BACKWARD:
- node.meta['saved_tensor'] = normalize_tuple(out)
+ node.meta["saved_tensor"] = normalize_tuple(out)
def wrap(x):
if isinstance(x, MetaTensor):
@@ -219,11 +218,14 @@ def wrap(x):
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
- x._node = subgraph.create_node('placeholder',
- 'placeholder', (subgraph._root,),
- name=subgraph._graph_namespace.create_name('input', x._tensor))
- x._node.meta['phase'] = Phase.PLACEHOLDER
- x._node.meta['saved_tensor'] = []
+ x._node = subgraph.create_node(
+ "placeholder",
+ "placeholder",
+ (subgraph._root,),
+ name=subgraph._graph_namespace.create_name("input", x._tensor),
+ )
+ x._node.meta["phase"] = Phase.PLACEHOLDER
+ x._node.meta["saved_tensor"] = []
return x
# Basically, we need to detach the args and kwargs from the outer graph.
@@ -235,7 +237,7 @@ def pack(x):
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
- x._node.meta['saved_tensor'] += [tensor]
+ x._node.meta["saved_tensor"] += [tensor]
if not do_not_cache:
cache.add(x._tensor.data_ptr())
return x
@@ -284,7 +286,7 @@ def unwrap(x):
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target', device: str = 'meta') -> Callable:
+def profile_function(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# find the grad for parameter in args and kwargs
param_size = 0
@@ -316,18 +317,18 @@ def get_param_size(x):
# still run the profiling but discard some results regarding `target`
global do_not_cache
- inplace = kwargs.get('inplace', False)
+ inplace = kwargs.get("inplace", False)
if target in OUTPUT_SAVED_OPS:
do_not_cache = True
if inplace:
do_not_cache = True
- kwargs['inplace'] = False
- if device == 'meta':
+ kwargs["inplace"] = False
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
- kwargs['inplace'] = True
+ kwargs["inplace"] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
@@ -341,7 +342,7 @@ def get_param_size(x):
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target', device: str = 'meta') -> Callable:
+def profile_method(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
- if device == 'meta':
+ assert isinstance(target, str), f"{target} instance is not str."
+ if device == "meta":
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
@@ -360,7 +361,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@compatibility(is_backward_compatible=True)
-def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
+def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# calculate parameter size
param_size = parameter_size(module)
@@ -384,13 +384,13 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# still run the profiling but discard some results regarding `module`.
global do_not_cache
- inplace = getattr(module, 'inplace', False)
+ inplace = getattr(module, "inplace", False)
if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True
if inplace:
do_not_cache = True
module.inplace = False
- if device == 'meta':
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
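`FlopTensor` hinges on `__torch_dispatch__`: every aten call made through the wrapper is intercepted, looked up in `flop_mapping`, and recorded on an fx node before the real (meta-device) computation runs. A stripped-down sketch of that interception mechanism, counting dispatched ops instead of FLOPs (`CountingTensor` is mine; `_make_wrapper_subclass` is the usual, if private, PyTorch hook for wrapper subclasses):

```python
import torch
from torch.utils._pytree import tree_map


class CountingTensor(torch.Tensor):
    """Wrapper tensor that counts every aten op dispatched through it."""

    op_count = 0

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device, requires_grad=elem.requires_grad
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        CountingTensor.op_count += 1  # a real profiler would call flop_mapping[func] here

        def unwrap(x):
            return x.elem if isinstance(x, CountingTensor) else x

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(lambda x: CountingTensor(x) if isinstance(x, torch.Tensor) else x, out)


x = CountingTensor(torch.empty(4, 4, device="meta"))
y = (x @ x).relu()
print(CountingTensor.op_count)  # typically 2: aten.mm and aten.relu
```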
diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py
index 34feefb4336a..75b7c814f05f 100644
--- a/colossalai/fx/profiler/shard_utils.py
+++ b/colossalai/fx/profiler/shard_utils.py
@@ -59,9 +59,9 @@ def forward(self, input_2):
Returns:
bool: Whether the node is a ReLU-like node
"""
- if n.op == 'call_function':
+ if n.op == "call_function":
return n.target in OUTPUT_SAVED_OPS
- elif n.op == 'call_module':
+ elif n.op == "call_module":
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 2ee5e5c47750..7c14b48bdaa1 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,13 +1,13 @@
import uuid
import torch
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
-__all__ = ['MetaTensor']
+__all__ = ["MetaTensor"]
def set_data_ptr(x):
@@ -43,12 +43,13 @@ def __new__(cls, elem, fake_device=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=elem.requires_grad) # deceive the frontend for aten selections
+ device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=elem.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor)
return r
@@ -69,15 +70,15 @@ def unwrap(x):
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
@@ -93,7 +94,7 @@ def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
@@ -120,18 +121,18 @@ def replace(x):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
- return 'meta'
+ return "meta"
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
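The `fake_device` bookkeeping exists because the wrapped tensor always lives on the `meta` device, which stores shapes and dtypes but no data, while pretending to sit on a real device for the frontend. The underlying PyTorch behavior the class builds on:

```python
import torch

# meta tensors carry shape/dtype only, so ops run as pure shape inference
a = torch.empty(64, 128, device="meta")
b = torch.empty(128, 32, device="meta")

out = a @ b
print(out.shape, out.device)  # torch.Size([64, 32]) meta
# out.cpu() would raise: there is no data to copy, hence MetaTensor's fake_device
```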
diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 7317072c6298..887832223fd6 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -1,12 +1,11 @@
-import operator
-from typing import Any, List, Union
+from typing import Any
import torch
-from torch.fx.proxy import Attribute, Proxy
+from torch.fx.proxy import Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_function
-__all__ = ['ColoProxy']
+__all__ = ["ColoProxy"]
class ColoProxy(Proxy):
@@ -39,11 +38,12 @@ def has_meta_data(self):
return self._meta_data is not None
def _assert_meta_data_is_tensor(self):
- assert torch.is_tensor(
- self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
+ assert (
+ torch.is_tensor(self._meta_data) and self._meta_data.is_meta
+ ), f"Meta data is not a meta tensor for {self.node.name}"
def _assert_has_meta_data(self):
- assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
+ assert self._meta_data is not None, f"Meta data is not set for {self.node.name}"
def __len__(self):
self._assert_has_meta_data()
@@ -62,7 +62,6 @@ def __bool__(self):
return self.meta_data
def __getattr__(self, k):
-
return ColoAttribute(self, k)
def __contains__(self, key):
@@ -92,7 +91,6 @@ def _convert(val):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py
index 1c5abb81d271..63a7bab654d5 100644
--- a/colossalai/fx/tracer/_meta_trace.py
+++ b/colossalai/fx/tracer/_meta_trace.py
@@ -39,7 +39,7 @@ class MetaProxy(torch.Tensor):
_tensor: torch.Tensor
_node: Node
- __slots__ = ['_tensor', '_node']
+ __slots__ = ["_tensor", "_node"]
@staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
@@ -51,22 +51,22 @@ def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
dtype=tensor.dtype,
layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device,
- requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
+ requires_grad=tensor.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
- name = 'input'
- r._node = graph.create_node('placeholder',
- 'placeholder', (graph._root,),
- name=namespace.create_name(name, tensor))
+ name = "input"
+ r._node = graph.create_node(
+ "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
+ )
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaProxy):
@@ -75,21 +75,21 @@ def unwrap(x):
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
def get_node(x):
- if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
- x = MetaProxy(x, placeholder=True, name='weight')
- return x if not hasattr(x, '_node') else x._node
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
+ x = MetaProxy(x, placeholder=True, name="weight")
+ return x if not hasattr(x, "_node") else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
- node = graph.create_node('call_function', func, args_node, kwargs_node)
+ node = graph.create_node("call_function", func, args_node, kwargs_node)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@@ -103,9 +103,12 @@ def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
- return MetaProxy(
- x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+ x = x.to(torch.device("meta"))
+ return (
+ MetaProxy(x, fake_device=fake_device)
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
+ else x
+ )
def set_node(x):
x._node = node
@@ -125,9 +128,12 @@ def wrap(x):
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
- grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
- tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
- torch.autograd.backward(tensor,
- MetaProxy(grad, fake_device=tensor.device, placeholder=True),
- retain_graph=True)
+ grad = (
+ torch.empty_like(tensor._tensor, device=torch.device("meta"))
+ if isinstance(tensor, MetaProxy)
+ else torch.empty_like(tensor, device=torch.device("meta"))
+ )
+ torch.autograd.backward(
+ tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
+ )
return graph
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index e160497a7444..9cf1961d45ff 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -2,10 +2,10 @@
import torch
-from ..proxy import ColoAttribute, ColoProxy
-from .meta_patch import meta_patched_function, meta_patched_module
+from ..proxy import ColoProxy
+from .meta_patch import meta_patched_function
-__all__ = ['is_element_in_list', 'extract_meta']
+__all__ = ["is_element_in_list", "extract_meta"]
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
def extract_meta(*args, **kwargs):
-
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
index 859a19bf6241..84c09109877e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,13 +7,12 @@
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
@@ -25,7 +21,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.bmm
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -35,10 +31,10 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
- '''
+ """
This method is used to sum the input_proxy through the sum_dims.
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
@@ -55,15 +51,15 @@ def generate(self):
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
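
The Addbmm patch above rebuilds `torch.addbmm` from a bias-free `torch.bmm`, a sum over the batch dimension, and two scalar multiplies. A quick numerical check of that identity (a standalone sketch, not part of the patch):

```python
import torch

# torch.addbmm(input, b1, b2, beta=beta, alpha=alpha)
#   == beta * input + alpha * torch.bmm(b1, b2).sum(dim=0)
b1, b2 = torch.randn(4, 3, 5), torch.randn(4, 5, 2)
inp = torch.randn(3, 2)
ref = torch.addbmm(inp, b1, b2, beta=0.5, alpha=2.0)
dec = 0.5 * inp + 2.0 * torch.bmm(b1, b2).sum(dim=0)
assert torch.allclose(ref, dec, atol=1e-6)
```
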
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
index fe7d8d07aac9..d087b2913005 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,17 +7,16 @@
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
- '''
+ """
This method is used to transpose the other operand for linear function.
For example:
input = torch.rand(3, 4)
@@ -30,8 +26,8 @@ def transpose_other_operand_for_linear(self, other_proxy):
# To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
# before we call the linear function.
new_output = torch.linear(m1, m2.transpose(0, 1)) + input
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
@@ -43,14 +39,14 @@ def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
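
Likewise, the Addmm patch rewrites `torch.addmm(input, m1, m2)` as a linear call on the transposed second operand plus the (scaled) input. A standalone sanity check of that substitution:

```python
import torch
import torch.nn.functional as F

m1, m2, inp = torch.randn(3, 4), torch.randn(4, 2), torch.randn(3, 2)
ref = torch.addmm(inp, m1, m2)
# F.linear computes m1 @ weight.T, so passing m2.transpose(0, 1)
# reproduces m1 @ m2, exactly as the patched graph does.
dec = F.linear(m1, m2.transpose(0, 1)) + inp
assert torch.allclose(ref, dec, atol=1e-6)
```
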
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
index 8a3786332c08..42178b7b786e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
@@ -29,7 +29,6 @@ def extract_kwargs_from_origin_func(self):
to insert two more operator.mul nodes for the computation graph to compute the
final result.
"""
- pass
@abstractmethod
def generate(self):
@@ -50,7 +49,6 @@ def generate(self):
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
- pass
def create_mul_node(self, input_proxy, coefficent):
"""
@@ -59,7 +57,7 @@ def create_mul_node(self, input_proxy, coefficent):
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = operator.mul
node_args = (
input_proxy,
@@ -82,7 +80,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.nn.functional.linear
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -96,7 +94,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
index e11ec0a364f1..ed060a350739 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
@@ -1,6 +1,3 @@
-import operator
-
-import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
@@ -9,17 +6,16 @@
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
- assert 'bias' in self.kwargs
+ assert "bias" in self.kwargs
kwargs = {}
- if 'bias' in self.kwargs:
- kwargs['bias'] = self.kwargs['bias']
+ if "bias" in self.kwargs:
+ kwargs["bias"] = self.kwargs["bias"]
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
- bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
+ bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
return bias_addition_proxy
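
The core bias-addition idea, checked eagerly (a sketch, not part of the patch): a biased `F.linear` is equivalent to a bias-free linear followed by an explicit add, which is the two-node subgraph these classes emit.

```python
import torch
import torch.nn.functional as F

x, w, b = torch.randn(5, 4), torch.randn(3, 4), torch.randn(3)
ref = F.linear(x, w, b)
dec = F.linear(x, w) + b  # create_non_bias_func_proxy + operator.add
assert torch.allclose(ref, dec, atol=1e-6)
```
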
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
index 591485fdb1ca..19c0e21d7c17 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
@@ -27,8 +27,8 @@ def _create_weight_proxy(self):
Note: this function will be invoked during module initializing,
you should never call this function.
"""
- weight_node_kind = 'get_attr'
- weight_node_target = self.target + '.weight'
+ weight_node_kind = "get_attr"
+ weight_node_target = self.target + ".weight"
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
@@ -39,8 +39,8 @@ def _create_bias_proxy(self):
Note: this function will be invoked during module initializing,
you should never call this function.
"""
- bias_node_kind = 'get_attr'
- bias_node_target = self.target + '.bias'
+ bias_node_kind = "get_attr"
+ bias_node_target = self.target + ".bias"
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
@@ -54,14 +54,13 @@ def extract_kwargs_from_mod(self):
considered during module initializing. However, we need to consider those attributes as kwargs
in F.conv2d.
"""
- pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
@@ -75,7 +74,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
This method is used to create the bias_addition_proxy, the node created by this proxy will
compute the sum of non_bias_func result and bias with some reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
@@ -100,7 +99,6 @@ def generate(self):
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
- pass
module_to_func_dict = {
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index 4b6c82a74f57..812a141c1eab 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -1,6 +1,5 @@
import torch
-import torch.nn.functional as F
-from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
+from torch.nn.modules.utils import _pair, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -10,17 +9,16 @@
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
- kwarg_attributes = ['groups', 'dilation', 'stride']
+ kwarg_attributes = ["groups", "dilation", "stride"]
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
- #TODO: non zeros mode requires some extra processing for input
+            # TODO: non-zero padding modes require extra processing of the input
conv_type = type(conv_module)
            if conv_type == torch.nn.Conv1d:
padding_element = _single(0)
@@ -28,9 +26,9 @@ def extract_kwargs_from_mod(self):
padding_element = _pair(0)
            elif conv_type == torch.nn.Conv3d:
padding_element = _triple(0)
- non_bias_kwargs['padding'] = padding_element
+ non_bias_kwargs["padding"] = padding_element
else:
- non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
+ non_bias_kwargs["padding"] = getattr(conv_module, "padding")
return non_bias_kwargs
@@ -41,11 +39,12 @@ def create_bias_reshape_proxy(self, dimensions):
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
- bias_reshape_node_kind = 'call_method'
- bias_reshape_node_target = 'view'
+ bias_reshape_node_kind = "call_method"
+ bias_reshape_node_target = "view"
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
- bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
- bias_reshape_node_args, {})
+ bias_reshape_proxy = self.tracer.create_proxy(
+ bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
+ )
return bias_reshape_proxy
def generate(self):
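
For convolutions the same split needs a reshape so the bias broadcasts over the spatial dimensions, which is what `create_bias_reshape_proxy` produces. A standalone eager-mode check of the decomposition the patched graph performs:

```python
import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 16, 16)
ref = conv(x)
# bias-free conv, then add the bias viewed as (-1, 1, 1) for broadcasting
dec = F.conv2d(x, conv.weight, None, conv.stride, conv.padding,
               conv.dilation, conv.groups) + conv.bias.view(-1, 1, 1)
assert torch.allclose(ref, dec, atol=1e-6)
```
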
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
index f6f7b6ddab40..b397f009846c 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
@@ -1,5 +1,4 @@
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -7,7 +6,6 @@
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
return {}
diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 22a67d1ceccc..e6e511b72fbb 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,4 +1,3 @@
-import enum
import functools
import inspect
import operator
@@ -10,7 +9,7 @@
from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
-from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
+from colossalai.fx.tracer._tracer_utils import is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import (
bias_addition_function,
@@ -24,31 +23,45 @@
from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str]
-Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
- List[Any], # actually Argument
- Dict[str, Any], # actually Argument
- slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
- 'Node',]]
-_CScriptMethod = ['add', 'mul', 'sub', 'div']
+Argument = Optional[
+ Union[
+ Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
+ List[Any], # actually Argument
+ Dict[str, Any], # actually Argument
+ slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
+ "Node",
+ ]
+]
+_CScriptMethod = ["add", "mul", "sub", "div"]
_TorchNewMethod = [
- "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor",
- "finfo"
+ "arange",
+ "zeros",
+ "zeros_like",
+ "ones",
+ "ones_like",
+ "full",
+ "full_like",
+ "empty",
+ "empty_like",
+ "eye",
+ "tensor",
+ "finfo",
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str):
import re
- return re.sub(r'_\d+$', '', s)
+
+ return re.sub(r"_\d+$", "", s)
def default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):
-
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@@ -100,7 +113,7 @@ def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -125,29 +138,28 @@ def ndim(self):
@property
def device(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {})
proxy.meta_data = self.meta_data.device
return proxy
@property
def dtype(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {})
proxy.meta_data = self.meta_data.dtype
return proxy
def to(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -160,11 +172,11 @@ def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
@@ -172,7 +184,6 @@ def __repr__(self):
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):
-
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self._disable_module_getattr = False
@@ -184,24 +195,28 @@ def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
@@ -211,20 +226,21 @@ def create_proxy(self,
proxy.meta_data = attr_itr
finally:
self._disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self._disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self._disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self._disable_module_getattr = True
try:
@@ -238,14 +254,15 @@ def create_node(self, *args, **kwargs) -> Node:
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -260,20 +277,19 @@ def trace(self,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
-            non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -292,7 +308,6 @@ def trace_activation_checkpoint(self, enabled: bool):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part
@@ -305,7 +320,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -356,10 +372,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -370,8 +389,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
@@ -389,42 +409,41 @@ def symbolic_trace(
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
- graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=meta_args)
+
+ graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=meta_args
+ )
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):
-
def __init__(self, tracer: Tracer):
self.overrides = {}
self.tracer = tracer
def __enter__(self):
-
def wrap_tensor_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True
try:
- proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = self.tracer.create_proxy("call_function", target, args, kwargs)
finally:
self.tracer._disable_module_getattr = False
return proxy
@@ -446,11 +465,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
setattr(torch, name, orig)
-def meta_prop_pass(gm: ColoGraphModule,
- root: torch.nn.Module,
- meta_args: Optional[Dict[str, Any]] = None,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None):
-
+def meta_prop_pass(
+ gm: ColoGraphModule,
+ root: torch.nn.Module,
+ meta_args: Optional[Dict[str, Any]] = None,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+):
if meta_args is None:
meta_args = {}
@@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
for node in gm.graph.nodes:
- node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
- node.kwargs)
+ node._meta_data = _meta_data_computing(
+ meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs
+ )
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
- if kind == 'placeholder':
+ if kind == "placeholder":
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
- elif kind == 'get_attr':
+ elif kind == "get_attr":
attr_itr = root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
meta_out = attr_itr
- elif kind == 'call_function':
+ elif kind == "call_function":
meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
- if target == '__call__':
+ elif kind == "call_method":
+ if target == "__call__":
meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_module':
+ meta_out = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
+ elif kind == "call_module":
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else:
@@ -603,26 +623,30 @@ def wrap_fn(n):
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
else:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target.__name__)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
- handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_method.get(method)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
@@ -631,8 +655,9 @@ def wrap_fn(n):
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
- handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_module.get(mod_type)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
if handle is not None:
handle.generate()
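
A minimal usage sketch for the experimental tracer (the import path and the availability of meta-tensor support in the installed PyTorch are assumptions; this is not part of the patch):

```python
import torch
from colossalai.fx.tracer.experimental import symbolic_trace  # assumed path

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x).relu()

# Passing meta_args records the graph with fake tensors, so tracing
# allocates no real activation memory.
gm = symbolic_trace(Net(), meta_args={"x": torch.empty(8, 16, device="meta")})
print(gm.graph)
```
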
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
index 12c42514895e..75d7b18a067c 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
@@ -5,4 +5,4 @@
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index 042b92c5847a..3475f22e3b19 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -4,7 +4,7 @@
@meta_patched_function.register(torch.matmul)
-@meta_patched_function.register('matmul') # for built-in op @
+@meta_patched_function.register("matmul") # for built-in op @
def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx
d1 = input.dim()
@@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None):
@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
- assert out is None, 'out is not supported yet'
- return torch.empty(input.shape, device='meta')
+ assert out is None, "out is not supported yet"
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.bmm)
@@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
- assert out is None, 'saving to out is not supported yet'
- var = torch.empty(1).squeeze(0).to('meta')
- mean = torch.empty(1).squeeze(0).to('meta')
+ assert out is None, "saving to out is not supported yet"
+ var = torch.empty(1).squeeze(0).to("meta")
+ mean = torch.empty(1).squeeze(0).to("meta")
return var, mean
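
These patched functions only need to reproduce shape and dtype semantics; on recent PyTorch versions the meta device does this natively, which is a convenient way to sanity-check a patch (standalone sketch):

```python
import torch

a = torch.empty(2, 3, 4, device="meta")
b = torch.empty(4, 5, device="meta")
# Broadcast matmul semantics, computed without any real data.
print(torch.matmul(a, b).shape)  # torch.Size([2, 3, 5])
```
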
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
index 8500e5c82508..26daf32a2afc 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
@@ -8,7 +8,6 @@
def _ntuple(n, name="parse"):
-
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
@@ -24,21 +23,21 @@ def parse(x):
def _extract_kwargs(kwargs):
- if 'stride' in kwargs:
- stride = kwargs['stride']
+ if "stride" in kwargs:
+ stride = kwargs["stride"]
else:
stride = 1
# TODO: process str type padding
- if 'padding' in kwargs:
- padding = kwargs['padding']
+ if "padding" in kwargs:
+ padding = kwargs["padding"]
else:
padding = 0
- if 'dilation' in kwargs:
- dilation = kwargs['dilation']
+ if "dilation" in kwargs:
+ dilation = kwargs["dilation"]
else:
dilation = 1
- if 'output_padding' in kwargs:
- output_padding = kwargs['output_padding']
+ if "output_padding" in kwargs:
+ output_padding = kwargs["output_padding"]
else:
output_padding = 0
@@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs):
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv2d)
@@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv3d)
@@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
@@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
l_in = input.shape[-1]
c_out = weight.shape[1]
- l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
@@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
h_in, w_in = input.shape[-2:]
c_out = weight.shape[1]
- h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
@@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
d_in, h_in, w_in = input.shape[-3:]
c_out = weight.shape[1]
- d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) +
- output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
index 6d8d864ea29a..27a79f18590a 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
@@ -4,11 +4,7 @@
@meta_patched_function.register(torch.nn.functional.embedding)
-def torch_nn_functional_embedding(input,
- weight,
- padding_idx=None,
- max_norm=None,
- norm_type=2.0,
- scale_grad_by_freq=False,
- sparse=False):
+def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
index e9e7eda6159c..8a6214990830 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
@@ -5,16 +5,11 @@
@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.nn.functional.batch_norm)
-def torch_nn_func_batchnorm(input,
- running_mean,
- running_var,
- weight=None,
- bias=None,
- training=False,
- momentum=0.1,
- eps=1e-05):
- return torch.empty(input.shape, device='meta')
+def torch_nn_func_batchnorm(
+ input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05
+):
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
index 4c171cb10991..7642934a409b 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
@@ -19,9 +19,9 @@ def to_concrete(t):
return t
def _slice_convert(slice_obj):
- attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step}
+ attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step}
new_attrs = _slice_attr_convert(attrs)
- attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step'])
+ attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"])
return slice(*attr_dict_to_tuple)
def _slice_attr_convert(attrs):
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index b14ff10ce137..c61e1c4dc9e1 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
shapes = [t.shape for t in tensors]
shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes)
- final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
+ final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
- assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
- "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
+ assert isinstance(repeats, int) or isinstance(
+ repeats, torch.Tensor
+ ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
shape = list(input.shape) if dim is not None else [input.numel()]
dim = dim if dim is not None else 0
@@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
- assert out is None, 'assigning result to out is not supported yet'
- return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
+ assert out is None, "assigning result to out is not supported yet"
+ return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad)
@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
if isinstance(dim, int):
shape = list(input.shape)
shape.pop(dim)
if keepdim:
shape.insert(dim, 1)
- return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
- device='meta',
- dtype=input.dtype)
+ return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty(
+ shape, device="meta", dtype=input.dtype
+ )
elif isinstance(dim, torch.Tensor):
# when dim is a 0D or 1D tensor, it will maintain the same shape
num_dims = dim.dim()
if num_dims in [0, 1]:
- return torch.empty_like(input, device='meta')
+ return torch.empty_like(input, device="meta")
else:
raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions")
else:
- return torch.empty([], device='meta', dtype=input.dtype)
+ return torch.empty([], device="meta", dtype=input.dtype)
@meta_patched_function.register(torch.Tensor.cpu)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
index e28e52585fff..3f40ec2a67ee 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
@@ -4,4 +4,4 @@
from .linear import *
from .normalization import *
from .pooling import *
-from .rnn import *
\ No newline at end of file
+from .rnn import *
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
index d03da6588c1c..aa2ede187d37 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
@@ -10,4 +10,4 @@
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
index cf9f3487aac9..35173a68a0be 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
@@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv2d)
@@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv3d)
@@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose1d)
@@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose2d)
@@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose3d)
@@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
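
The transposed-convolution output sizes above follow the formula from the PyTorch docs. A standalone check against a real module (the sizes are hypothetical):

```python
import math
import torch

m = torch.nn.ConvTranspose1d(4, 6, kernel_size=3, stride=2, padding=1,
                             output_padding=1, dilation=2)
l_in = 10
l_out = math.floor((l_in - 1) * m.stride[0] - 2 * m.padding[0]
                   + m.dilation[0] * (m.kernel_size[0] - 1)
                   + m.output_padding[0] + 1)
y = m(torch.randn(2, 4, l_in))
assert y.shape == (2, 6, l_out)  # l_out == 22 here
```
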
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
index 999e33b17c1c..f28647e9caa5 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
@@ -6,4 +6,4 @@
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
result_shape = input.shape + (self.embedding_dim,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
index 56f13bf97532..97e6b0e96e83 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
@@ -6,5 +6,7 @@
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
last_dim = input.shape[-1]
- assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
+ assert (
+ last_dim == self.in_features
+ ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch"
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
index c21ff64cf3de..198e72e342b1 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
@@ -23,6 +23,7 @@ def torch_nn_normalize(self, input):
try:
import apex
+
meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
index 7ce23fbf7ac9..450586d02f8f 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
@@ -8,7 +8,7 @@
@meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -25,13 +25,13 @@ def _convert_int_to_list(item):
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -52,13 +52,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -81,13 +81,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -105,13 +105,13 @@ def _convert_int_to_list(item):
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -133,13 +133,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -163,7 +163,7 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-1]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-2]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-3]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
index ee15ca34162e..bfb7ed171186 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
import torch
from ...registry import meta_patched_module
@@ -8,9 +6,11 @@
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
- assert input.shape[
- -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
- assert hx.shape[
- -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
+ assert (
+ input.shape[-1] == self.input_size
+ ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch"
+ assert (
+ hx.shape[-1] == self.hidden_size
+ ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch"
d = 2 if self.bidirectional else 1
return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py
index 12fc6de73d44..80b3868bb4fe 100644
--- a/colossalai/fx/tracer/registry.py
+++ b/colossalai/fx/tracer/registry.py
@@ -1,11 +1,9 @@
class PatchRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,8 +19,8 @@ def has(self, source):
return source in self.store
-meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
-meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
-bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
-bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
-bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
+meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution")
+meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution")
+bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition")
+bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition")
+bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition")
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 28965a1b8e74..d9cb587b5d39 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -29,7 +29,7 @@
meta_patched_module,
)
-__all__ = ['ColoTracer']
+__all__ = ["ColoTracer"]
class TracerType(enum.Enum):
@@ -103,7 +103,7 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
@@ -160,22 +160,27 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
- kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
- lambda node: ParameterProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
- maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
- parameter_proxy_cache)
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
@@ -190,7 +195,7 @@ def call_module(self, m, forward, args, kwargs):
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
- return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+ return self.create_proxy("call_module", module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
@@ -211,7 +216,6 @@ def _configure_tracer_type(self, tracer_type: TracerType):
raise ValueError(f"Unrecognized tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs):
-
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target]
return meta_out
@@ -235,8 +239,9 @@ def _meta_data_computing(self, kind, target, args, kwargs):
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False
- if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
- torch.nn.parameter.Parameter):
+ if target in (torch.transpose, torch.reshape) and isinstance(
+ args_metas[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# fetch patched function
if meta_patched_function.has(target):
@@ -309,10 +314,12 @@ def _meta_data_computing(self, kind, target, args, kwargs):
return meta_out
- def trace(self,
- root: nn.Module,
- concrete_args: Optional[Dict[str, Tensor]] = None,
- meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
+ def trace(
+ self,
+ root: nn.Module,
+ concrete_args: Optional[Dict[str, Tensor]] = None,
+ meta_args: Optional[Dict[str, Tensor]] = None,
+ ) -> Graph:
"""
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
@@ -341,9 +348,7 @@ def trace(self,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
@@ -354,7 +359,8 @@ def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -363,11 +369,13 @@ def _check_arg_name_valid(names):
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
if not should_be_meta:
- assert not torch.is_tensor(v) or not v.is_meta, \
- f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+ assert (
+ not torch.is_tensor(v) or not v.is_meta
+ ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
else:
- assert v.is_meta == should_be_meta, \
- f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+ assert (
+ v.is_meta == should_be_meta
+ ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)
@@ -442,7 +450,6 @@ def trace_activation_checkpoint(self, enabled: bool):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part
@@ -455,7 +462,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -470,12 +478,11 @@ def create_node(self, *args, **kwargs) -> Node:
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
def wrap_tensor_constructor_method(target):
-
def look_for_proxy(*args, **kwargs):
# find in pos vars
for arg in args:
@@ -518,12 +525,10 @@ def wrapper(*args, **kwargs):
for method in magic_methods:
def _scope(method):
-
def impl(*args, **kwargs):
-
tracer = args[0].tracer
target = getattr(operator, method)
- proxy = tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node)
@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):
def impl(self, rhs):
target = getattr(operator, orig_method_name)
- proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+ proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node)
diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py
deleted file mode 100644
index 61b31965e2e6..000000000000
--- a/colossalai/global_variables.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from typing import Optional
-
-
-class TensorParallelEnv(object):
- _instance = None
-
- def __new__(cls, *args, **kwargs):
- if cls._instance is None:
- cls._instance = object.__new__(cls, *args, **kwargs)
- return cls._instance
-
- def __init__(self, *args, **kwargs):
- self.load(*args, **kwargs)
-
- def load(self,
- mode: Optional[str] = None,
- vocab_parallel: bool = False,
- parallel_input_1d: bool = False,
- summa_dim: int = None,
- tesseract_dim: int = None,
- tesseract_dep: int = None,
- depth_3d: int = None,
- input_group_3d=None,
- weight_group_3d=None,
- output_group_3d=None,
- input_x_weight_group_3d=None,
- output_x_weight_group_3d=None):
- self.mode = mode
- self.vocab_parallel = vocab_parallel
- self.parallel_input_1d = parallel_input_1d
- self.summa_dim = summa_dim
- self.tesseract_dim = tesseract_dim
- self.tesseract_dep = tesseract_dep
- self.depth_3d = depth_3d
- self.input_group_3d = input_group_3d
- self.weight_group_3d = weight_group_3d
- self.output_group_3d = output_group_3d
- self.input_x_weight_group_3d = input_x_weight_group_3d
- self.output_x_weight_group_3d = output_x_weight_group_3d
-
- def save(self):
- return dict(mode=self.mode,
- vocab_parallel=self.vocab_parallel,
- parallel_input_1d=self.parallel_input_1d,
- summa_dim=self.summa_dim,
- tesseract_dim=self.tesseract_dim,
- tesseract_dep=self.tesseract_dep,
- depth_3d=self.depth_3d,
- input_group_3d=self.input_group_3d,
- weight_group_3d=self.weight_group_3d,
- output_group_3d=self.output_group_3d,
- input_x_weight_group_3d=self.input_x_weight_group_3d,
- output_x_weight_group_3d=self.output_x_weight_group_3d)
-
-
-tensor_parallel_env = TensorParallelEnv()
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py
new file mode 100644
index 000000000000..c035f397923a
--- /dev/null
+++ b/colossalai/inference/quant/gptq/__init__.py
@@ -0,0 +1,4 @@
+from .cai_gptq import HAS_AUTO_GPTQ
+
+if HAS_AUTO_GPTQ:
+ from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py
new file mode 100644
index 000000000000..de57f2d8cfee
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/__init__.py
@@ -0,0 +1,13 @@
+import warnings
+
+HAS_AUTO_GPTQ = False
+try:
+ import auto_gptq
+ HAS_AUTO_GPTQ = True
+except ImportError:
+ warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
+ HAS_AUTO_GPTQ = False
+
+if HAS_AUTO_GPTQ:
+ from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
+ from .gptq_op import CaiGPTQLinearOp
diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py
new file mode 100644
index 000000000000..ca12c34ed958
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,354 @@
+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
+
+import math
+import warnings
+from typing import List, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .gptq_op import CaiGPTQLinearOp
+
+HAS_GPTQ_CUDA = False
+try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+except ImportError:
+ warnings.warn('CUDA gptq is not installed')
+ HAS_GPTQ_CUDA = False
+
+
+class CaiQuantLinear(nn.Module):
+
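+    # GPTQ linear layer: quantized weights are stored packed in int32 buffers (qweight/qzeros) with fp16 scales.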
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+ super().__init__()
+ if bits not in [2, 4, 8]:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ self.infeatures = infeatures
+ self.outfeatures = outfeatures
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize if groupsize != -1 else infeatures
+
+ self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+ self.register_buffer(
+ 'qzeros',
+ torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+ self.register_buffer('scales',
+ torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+ if row_split:
+ self.register_buffer(
+ 'g_idx',
+ torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
+ dtype=torch.int32))
+ else:
+ self.register_buffer('g_idx',
+ torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+
+ if bias:
+ self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
+
+ self.q4 = None
+ self.empty_tensor = torch.empty((1, 1), device="meta")
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+ self.row_split = row_split
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+
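+        # Quantize the fp16 weights of the source linear layer and pack them into the int32 qweight/qzeros buffers.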
+ g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
+ [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ half_scales = scales.clone().half()
+ # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ wn = 8
+ pbits = 32
+ ptype = torch.int32
+ unsign_type = np.uint32
+ sign_type = np.int32
+
+ intweight = []
+ for idx in range(self.infeatures):
+            quantized_col = torch.round(
+                (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]
+            ).to(ptype)[:, None]
+            intweight.append(quantized_col)
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(unsign_type)
+ qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
+
+ i = 0
+ row = 0
+
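+        # Each int32 row of qweight packs (32 // bits) rows of quantized values.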
+ while row < qweight.shape[0]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ qweight = qweight.astype(sign_type)
+ qweight1 = torch.from_numpy(qweight)
+ qweight1 = qweight1.contiguous() #.to("cuda")
+ self.qweight.data.copy_(qweight1)
+
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
+ zeros -= 1
+ zeros = zeros.numpy().astype(unsign_type)
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ qzeros = qzeros.astype(sign_type)
+ qzeros = torch.from_numpy(qzeros)
+ self.qzeros.data.copy_(qzeros)
+
+ if torch.equal(self.g_idx.to(g_idx.device), g_idx):
+ self.g_idx = None
+ else:
+ self.g_idx = g_idx
+
+ def init_q4(self):
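+        # Build the exllama-style q4 handle once; g_idx is dropped when it matches the default layout.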
+ assert self.qweight.device.type == "cuda"
+ self.q4_width = self.qweight.shape[1]
+ if self.g_idx is not None:
+ if self.row_split and torch.equal(
+ self.g_idx,
+ torch.tensor(
+ [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
+ dtype=torch.int32,
+ device=self.g_idx.device)):
+ self.g_idx = None
+ elif torch.equal(
+ self.g_idx,
+ torch.tensor([i // self.groupsize for i in range(self.infeatures)],
+ dtype=torch.int32,
+ device=self.g_idx.device)):
+ self.g_idx = None
+
+ if self.g_idx is not None:
+ g_idx = self.g_idx.to("cpu")
+ else:
+ g_idx = self.empty_tensor
+
+ self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
+ torch.cuda.synchronize()
+
+ def forward(self, x):
+ outshape = x.shape[:-1] + (self.outfeatures,)
+
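+        # Fast path: fused CUDA q4 matmul for 4-bit weights; otherwise fall back to the Triton GPTQ op.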
+ if HAS_GPTQ_CUDA and self.bits == 4:
+
+ if self.q4 is None:
+ self.init_q4()
+
+ x = x.view(-1, x.shape[-1])
+ output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
+ gptq_cuda.q4_matmul(x.half(), self.q4, output)
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ output.add_(self.bias)
+ else:
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ bias = self.bias
+ else:
+ bias = None
+ output = self.gptq_linear(
+ x,
+ self.qweight,
+ self.scales,
+ self.qzeros,
+ g_idx=self.g_idx,
+ bias=bias,
+ )
+ return output.view(outshape)
+
+
+def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
+
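+    # Copy this tensor-parallel rank's shard of the packed buffers into cai_linear, splitting along the output dim.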
+ qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
+ scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
+ g_idx = gptq_linear.g_idx
+ if gptq_linear.bias is not None:
+ bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)
+
+ cai_split_out_features = cai_linear.outfeatures // split_num
+ zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
+
+ for i in range(split_num):
+        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
+        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
+            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
+        ]
+        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
+        if cai_linear.bias is not None:
+            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
+                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+            ]
+
+ cai_linear.g_idx.copy_(g_idx)
+
+
+def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
+
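+    # Copy this tensor-parallel rank's shard of the packed buffers into cai_linear, splitting along the input dim.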
+ qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
+ scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
+ g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)
+
+ cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
+ zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
+ idx_split_features = cai_linear.infeatures // split_num
+
+ for i in range(split_num):
+        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
+            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
+        ]
+        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
+            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
+        ]
+ if cai_linear.bias is not None:
+ cai_linear.bias.copy_(gptq_linear.bias)
+
+
+class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
+
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+
+ super().__init__(bits,
+ groupsize,
+ infeatures,
+ outfeatures,
+ bias,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=row_split)
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, \
+ f'Expected only one process group, got {len(process_group)}.'
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features ({in_features}) is not an integer multiple of tensor parallel size ({tp_size})!")
+ linear_1d = RowCaiQuantLinear(module.bits,
+ module.group_size,
+ module.in_features // tp_size,
+ module.out_features,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=True)
+ linear_1d.process_group = process_group
+
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+ def forward(self, x):
+ output = super().forward(x)
+ if self.tp_size > 1:
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+ if self.bias is not None:
+ output.add_(self.bias)
+ return output
+
+
+class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
+
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+
+ super().__init__(bits,
+ groupsize,
+ infeatures,
+ outfeatures,
+ bias,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=row_split)
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, \
+ f'Expected only one process group, got {len(process_group)}.'
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features ({in_features}) is not an integer multiple of tensor parallel size ({tp_size})!")
+ linear_1d = ColCaiQuantLinear(module.bits,
+ module.group_size,
+ module.in_features,
+ module.out_features // tp_size,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank)
+ linear_1d.process_group = process_group
+
+ split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py
new file mode 100644
index 000000000000..a8902eb35cd0
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,58 @@
+import torch
+
+from colossalai.kernel.triton import gptq_fused_linear_triton
+
+
+class CaiGPTQLinearOp(torch.nn.Module):
+ def __init__(self, gptq_group_size, gptq_quant_bits):
+ super(CaiGPTQLinearOp, self).__init__()
+ self.group_size = gptq_group_size
+ self.bits = gptq_quant_bits
+ self.maxq = 2**self.bits - 1
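+        # Placeholder tensor passed to the fused kernel when bias/residual are absent.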
+ self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scales: torch.Tensor,
+ weight_zeros: torch.Tensor,
+ g_idx: torch.Tensor = None,
+ act_type=0,
+ bias: torch.Tensor = None,
+ residual: torch.Tensor = None,
+ qkv_fused=False,
+ ):
+ add_bias = True
+ if bias is None:
+ bias = self.empty_tensor
+ add_bias = False
+
+ add_residual = True
+ if residual is None:
+ residual = self.empty_tensor
+ add_residual = False
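+        # Flatten the input to 2D before calling the fused dequantize-and-matmul Triton kernel.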
+ x = input.view(-1, input.shape[-1])
+
+ out = gptq_fused_linear_triton(
+ x,
+ weight,
+ weight_scales,
+ weight_zeros,
+ bias,
+ residual,
+ self.bits,
+ self.maxq,
+ self.group_size,
+ qkv_fused,
+ add_bias,
+ add_residual,
+ act_type=act_type,
+ g_idx=g_idx,
+ )
+ if qkv_fused:
+ out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
+ else:
+ out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
+
+ return out
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py
index e467b4c73e6b..112b920ba158 100644
--- a/colossalai/inference/tensor_parallel/__init__.py
+++ b/colossalai/inference/tensor_parallel/__init__.py
@@ -1,4 +1,4 @@
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager
-__all__ = ['MemoryManager', 'TPInferEngine']
+__all__ = ["MemoryManager", "TPInferEngine"]
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
index 2bff9317283e..ac185f1b6529 100644
--- a/colossalai/inference/tensor_parallel/batch_infer_state.py
+++ b/colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -1,6 +1,5 @@
# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
-from typing import Any
import torch
@@ -31,7 +30,7 @@ class BatchInferState:
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None
- device: torch.device = torch.device('cuda')
+ device: torch.device = torch.device("cuda")
@property
def total_token_num(self):
@@ -43,13 +42,15 @@ def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager
@staticmethod
- def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
- alloc_mem_index: torch.Tensor):
- """ in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+ def init_block_loc(
+ b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
+ ):
+        """In-place update of the block location mapping based on the sequence lengths in the current batch."""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
- b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
- cur_seq_len]
+ b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
+ start_index : start_index + cur_seq_len
+ ]
start_index += cur_seq_len
return
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index a5a55702ade0..d5ef37fee420 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -1,6 +1,7 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+import warnings
+from typing import Any, Callable, List, Optional, Union
import torch
+import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
@@ -15,7 +16,13 @@
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
-_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
+_supported_models = [
+ "LlamaForCausalLM",
+ "LlamaModel",
+ "BloomForCausalLM",
+ "ChatGLMModel",
+ "ChatGLMForConditionalGeneration",
+]
class TPInferEngine:
@@ -39,14 +46,16 @@ class TPInferEngine:
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
- def __init__(self,
- model: nn.Module,
- shard_config: ShardConfig,
- max_batch_size: int,
- max_input_len: int,
- max_output_len: int,
- dtype: torch.dtype = torch.float16,
- device: str = 'cuda') -> None:
+ def __init__(
+ self,
+ model: nn.Module,
+ shard_config: ShardConfig,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ dtype: torch.dtype = torch.float16,
+ device: str = "cuda",
+ ) -> None:
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
@@ -61,11 +70,24 @@ def __init__(self,
self.head_dim = model.config.hidden_size // model.config.num_attention_heads
self.head_num = model.config.num_attention_heads
- self.layer_num = model.config.num_hidden_layers
-
- self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
+ num_hidden_layers = (
+ model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
+ )
+ self.layer_num = num_hidden_layers
+ self.multi_query_group_num = (
+ model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
+ )
+
+        self.tp_size = -1  # to be set with the given shard config in self._prepare_with_shard_config
self.cache_manager = None
+ self.max_dq_buffer_size = 1
+ self.max_inner_outer_dim = 1
+ self.gptq_temp_state_buffer = None
+ self.gptq_temp_dq_buffer = None
+ self.bits = -1
+ self.use_act_order = False
+
self.shard_config = shard_config
self.model = None
# optimize the original model by sharding with ShardFormer
@@ -74,9 +96,67 @@ def __init__(self,
def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
- self.head_num //= self.tp_size # update sharded number of heads
- self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
- self.layer_num)
+ self.head_num //= self.tp_size # update sharded number of heads
+ if self.multi_query_group_num:
+ # NOTE the logic of MQA tensor parallelism should be specified.
+ assert (
+ self.multi_query_group_num % self.tp_size == 0
+ ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
+ self.cache_manager = MemoryManager(
+ self.max_total_token_num,
+ self.dtype,
+ self.multi_query_group_num // self.tp_size,
+ self.head_dim,
+ self.layer_num,
+ )
+ else:
+ self.cache_manager = MemoryManager(
+ self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
+ )
+
+ def _post_init_gptq_buffer(self, model: nn.Module) -> None:
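+        # Size and register the scratch buffers required by the exllama-style CUDA kernels (4-bit only).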
+ from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
+ HAS_GPTQ_CUDA = False
+ try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+ except ImportError:
+            warnings.warn("CUDA gptq is not installed")
+ HAS_GPTQ_CUDA = False
+
+ for name, submodule in model.named_modules():
+ if isinstance(submodule, CaiQuantLinear):
+ self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
+
+ if self.use_act_order:
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
+ self.bits = submodule.bits
+ if not (HAS_GPTQ_CUDA and self.bits == 4):
+ return
+
+ max_input_len = 1
+ if self.use_act_order:
+ max_input_len = self.max_input_len
+ # The temp_state buffer is required to reorder X in the act-order case.
+ # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
+ # Using the default from exllama repo here.
+ matmul_recons_thd = 8
+ matmul_fused_remap = False
+ matmul_no_half2 = False
+ gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+ torch.cuda.empty_cache()
def _optimize_model(self, model: nn.Module) -> None:
"""
@@ -90,7 +170,7 @@ def _optimize_model(self, model: nn.Module) -> None:
self._shard_model_by(shardformer, model)
def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
- """ Prepare the engine with a given ShardConfig.
+ """Prepare the engine with a given ShardConfig.
Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
@@ -118,13 +198,18 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
return shard_config
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
- """ Shard original model by the given ShardFormer and store the sharded model. """
- assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
- "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
+ """Shard original model by the given ShardFormer and store the sharded model."""
+ assert (
+ self.tp_size == shardformer.shard_config.tensor_parallel_size
+ ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
+
+ if self.shard_config.inference_gptq:
+ self._post_init_gptq_buffer(model)
+
self.model = self.model.cuda()
@property
@@ -147,7 +232,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()
- if 'max_new_tokens' not in generate_kwargs:
+ if "max_new_tokens" not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
@@ -176,18 +261,18 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
attention_mask = None
if isinstance(inputs, (BatchEncoding, dict)):
- input_ids_list = inputs['input_ids']
- attention_mask = inputs['attention_mask']
+ input_ids_list = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
else:
input_ids_list = inputs
- if isinstance(input_ids_list[0], int): # for a single input
+ if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
- seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
- seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
@@ -210,10 +295,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
- block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
- batch_infer_state.seq_len = seq_lengths.to('cuda')
- batch_infer_state.start_loc = seq_start_indexes.to('cuda')
+ batch_infer_state.seq_len = seq_lengths.to("cuda")
+ batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
@@ -248,7 +333,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
- setattr(model, 'infer_state', batch_infer_state)
+ setattr(model, "infer_state", batch_infer_state)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
@@ -262,14 +347,15 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
- def _generate_by_pass_infer_state(self,
- input_tokens,
- max_out_length: int,
- generation_config: Optional[GenerationConfig] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
-
+ def _generate_by_pass_infer_state(
+ self,
+ input_tokens,
+ max_out_length: int,
+ generation_config: Optional[GenerationConfig] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ **model_kwargs,
+ ) -> torch.Tensor:
raise NotImplementedError("generate by passing BatchInferState is not implemented.")
# might want to use in rewritten generate method: use after model.forward
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
index 274c01841279..e74a3a491a7b 100644
--- a/colossalai/inference/tensor_parallel/kvcache_manager.py
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -19,13 +19,15 @@ class MemoryManager:
device: device used to store the key and value cache
"""
- def __init__(self,
- size: int,
- dtype: torch.dtype,
- head_num: int,
- head_dim: int,
- layer_num: int,
- device: torch.device = torch.device('cuda')):
+ def __init__(
+ self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device("cuda"),
+ ):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
@@ -33,13 +35,13 @@ def __init__(self,
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
- """ Initialize tensors used to manage memory states """
+ """Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
- """ Initialize key buffer and value buffer on specified device """
+ """Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
@@ -49,10 +51,9 @@ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
@torch.no_grad()
def alloc(self, required_size):
- """ allocate space of required_size by providing indexes representing available physical spaces """
+ """allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
- self.logger.warning(f"No enough cache: required_size {required_size} "
- f"left_size {self.available_size}")
+            self.logger.warning(f"Not enough cache: required_size {required_size} left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
@@ -63,23 +64,25 @@ def alloc(self, required_size):
@torch.no_grad()
def alloc_contiguous(self, required_size):
- """ allocate contiguous space of required_size """
+ """allocate contiguous space of required_size"""
if required_size > self.available_size:
- self.logger.warning(f"No enough cache: required_size {required_size} "
- f"left_size {self.available_size}")
+            self.logger.warning(f"Not enough cache: required_size {required_size} left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
- loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
- 1] + self.mem_state[0:sum_size -
- required_size + 1]
- can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
+ loc_sums = (
+ self.mem_cum_sum[required_size - 1 :]
+ - self.mem_cum_sum[0 : sum_size - required_size + 1]
+ + self.mem_state[0 : sum_size - required_size + 1]
+ )
+ can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
- self.logger.info(f"No enough contiguous cache: required_size {required_size} "
- f"left_size {self.available_size}")
+ self.logger.info(
+                f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}"
+ )
return None
start_loc = can_used_loc[0]
- select_index = self.indexes[start_loc:start_loc + required_size]
+ select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
@@ -88,13 +91,13 @@ def alloc_contiguous(self, required_size):
@torch.no_grad()
def free(self, free_index):
- """ free memory by updating memory states based on given indexes """
+ """free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
- """ free all memory by updating memory states """
+ """free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py
index 7a98b033f37e..279b54065eed 100644
--- a/colossalai/inference/tensor_parallel/modeling/__init__.py
+++ b/colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -1,4 +1,7 @@
+from . import _utils
+
from .bloom import BloomInferenceForwards
+from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards
-__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
+__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py
new file mode 100644
index 000000000000..cee418707617
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -0,0 +1,10 @@
+"""
+Utils for model inference
+"""
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+
+
+def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
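+    # Scatter the newly computed key/value tensors into the cache slots selected by context_mem_index.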
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+ return
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index 9768fc425628..27a26caabefa 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -1,6 +1,6 @@
import math
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -17,9 +17,7 @@
from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
def generate_alibi(n_head, dtype=torch.float16):
@@ -33,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16):
"""
def get_slopes_power_of_2(n):
- start = 2**(-(2**-(math.log2(n) - 3)))
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
- closest_power_of_2 = 2**math.floor(math.log2(n))
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
- slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined
slopes = get_slopes(n_head)
@@ -74,7 +72,6 @@ def bloom_model_forward(
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
-
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
@@ -88,8 +85,9 @@ def bloom_model_forward(
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -124,14 +122,15 @@ def bloom_model_forward(
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
use_cache = False
# NOTE determine if BatchInferState is passed in via arg
# if not, get the attr binded to the model
# We might wantto remove setattr later
if infer_state is None:
- assert hasattr(self, 'infer_state')
+ assert hasattr(self, "infer_state")
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
@@ -148,10 +147,11 @@ def bloom_model_forward(
if use_cache and seq_length != 1:
# prefill stage
- infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
- BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
- infer_state.context_mem_index)
+ BatchInferState.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -184,8 +184,11 @@ def bloom_model_forward(
# alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank()
- alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
- self.num_heads].cuda()
+ alibi = (
+ generate_alibi(self.num_heads * tp_size)
+ .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
+ .cuda()
+ )
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
@@ -199,7 +202,6 @@ def bloom_model_forward(
if self.gradient_checkpointing and self.training:
# NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module):
-
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
@@ -252,32 +254,34 @@ def custom_forward(*inputs):
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
- past_key_values=presents, # should always be (None, None, ..., None)
+ past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@staticmethod
- def bloom_for_causal_lm_forward(self: BloomForCausalLM,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- infer_state: Optional[BatchInferState] = None,
- **deprecated_arguments):
+ def bloom_for_causal_lm_forward(
+ self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
- logger = logging.get_logger(__name__)
+ logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
@@ -291,17 +295,19 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- infer_state=infer_state)
+ transformer_outputs = BloomInferenceForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ )
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -316,8 +322,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
- shift_labels.view(batch_size * seq_length))
+ loss = loss_fct(
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+ )
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -355,11 +362,13 @@ def bloom_for_causal_lm_prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}
- model_inputs.update({
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- })
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
return model_inputs
@staticmethod
@@ -418,7 +427,7 @@ def bloom_block_forward(
else:
outputs = (output,) + outputs[1:]
- return outputs # hidden_states, present, attentions
+ return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
@@ -433,20 +442,19 @@ def bloom_attention_forward(
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
-
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
- k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
- v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
+        k = key_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
- if layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_length # += 1
+ if layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage:
# context process
@@ -473,9 +481,11 @@ def bloom_attention_forward(
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_k.copy_(k)
cache_v.copy_(v)
else:
@@ -488,8 +498,17 @@ def bloom_attention_forward(
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
- token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
- b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
+ token_attention_fwd(
+ q,
+ mem_manager.key_buffer[layer_id],
+ mem_manager.value_buffer[layer_id],
+ output,
+ b_loc,
+ b_start_loc,
+ b_seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ alibi,
+ )
context_layer = output.view(batch_size, q_length, H * D_HEAD)
@@ -506,8 +525,8 @@ def bloom_attention_forward(
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
- context_layer[:, :, int(i * slices):int((i + 1) * slices)],
- self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
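
The `pretraining_tp` branch above splits the dense projection into input-feature slices and sums the partial `F.linear` results. A minimal standalone sketch, with hypothetical sizes, showing that the sliced sum reproduces the full projection:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 16
pretraining_tp = 4                     # hypothetical pretraining tensor-parallel degree
x = torch.randn(2, 3, hidden)          # [batch_size, seq_length, hidden]
weight = torch.randn(hidden, hidden)   # stand-in for self.dense.weight, shape [out, in]

slices = hidden // pretraining_tp
out = torch.zeros(2, 3, hidden)
for i in range(pretraining_tp):
    # each step projects one slice of the input features with the matching weight columns
    out = out + F.linear(x[:, :, i * slices : (i + 1) * slices], weight[:, i * slices : (i + 1) * slices])

# the partial sums recover the unsliced projection
assert torch.allclose(out, F.linear(x, weight), atol=1e-5)
```
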
diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
new file mode 100644
index 000000000000..4b1bc601f436
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -0,0 +1,540 @@
+import os
+from typing import Optional, Tuple
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
+from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+ split_tensor_along_last_dim,
+)
+
+from ._utils import copy_kv_to_mem_cache
+
+
+# This func is the same as the Llama model's init_to_get_rotary; we should move both into _utils.py
+def _init_to_get_rotary(self, base=10000):
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ try:
+ ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula
+ except:
+ pass
+ n_elem = self.config.head_dim_ // 2
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ return
+
+
+def get_masks(self, input_ids, past_length, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ if past_length:
+ full_attention_mask = torch.cat(
+ (
+ torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
+ full_attention_mask,
+ ),
+ dim=-1,
+ )
+
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
+
+
+class ChatGLM2InferenceForwards:
+ """
+    This class holds the forward replacements for ChatGLM2 inference.
+    We intend to replace the forward methods of ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
+ """
+
+ @staticmethod
+ def chatglm_for_conditional_generation_forward(
+ self: ChatGLMForConditionalGeneration,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ infer_state = self.infer_state
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = 0
+
+ # NOT READY FOR PRIME TIME
+        # dummy but works; revise it later
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ seq_length_with_past = seq_length + past_key_values_length
+ infer_state.seq_length_with_past = seq_length_with_past
+
+ # prefill stage at first
+ if use_cache and seq_length != 1:
+ infer_state.is_context_stage = True
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ # related to rotary embedding
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ )
+
+ hidden_states = transformer_outputs[0]
+ if return_last_logit:
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def chatglm_model_forward(
+ self: ChatGLMModel,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = get_masks(
+ self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
+ )
+
+ # Run encoder.
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds,
+ full_attention_mask,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ infer_state=infer_state,
+ )
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+ infer_state.cache_manager.past_key_values_length += seq_length
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ @staticmethod
+ def chatglm_encoder_forward(
+ self: GLMTransformer,
+ hidden_states,
+ attention_mask,
+ kv_caches=None,
+ use_cache: Optional[bool] = True,
+ output_hidden_states: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+ if not kv_caches:
+ kv_caches = [None for _ in range(self.num_layers)]
+ presents = () if use_cache else None
+ all_self_attentions = None
+ all_hidden_states = () if output_hidden_states else None
+
+ infer_state.decode_layer_id = 0
+ for index in range(self.num_layers):
+ layer = self.layers[index]
+
+ layer_ret = layer(
+ hidden_states,
+ attention_mask,
+ kv_cache=kv_caches[index],
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ infer_state.decode_layer_id += 1
+
+ hidden_states, kv_cache = layer_ret
+ if use_cache:
+ presents = presents + (kv_cache,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # Final layer norm.
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+ if self.post_layer_norm:
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states, presents, all_hidden_states, all_self_attentions
+
+ @staticmethod
+ def chatglm_glmblock_forward(
+ self: GLMBlock,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+ return output, kv_cache
+
+ @staticmethod
+ def chatglm_flash_attn_kvcache_forward(
+ self: SelfAttention,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+        assert use_cache is True, "use_cache should be set to True when using this chatglm attention"
+        # hidden_states: originally [sq, b, h] --> here [b, sq, h]
+ batch_size = hidden_states.shape[0]
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(
+ query_layer.size()[:-1]
+ + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ key_layer = key_layer.view(
+ key_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ value_layer = value_layer.view(
+ value_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+
+ Llama2Forwards.rotary_emb_fwd(
+ query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
+ )
+ if self.multi_query_attention:
+ Llama2Forwards.rotary_emb_fwd(
+ key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+ else:
+ Llama2Forwards.rotary_emb_fwd(
+ key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+
+        # reshape q, k, v to [bsz * seq_len, num_heads, head_dim]
+ query_layer = query_layer.reshape(
+ -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ )
+ key_layer = key_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+ value_layer = value_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+ if infer_state.is_context_stage:
+ # first token generation:
+ # copy key and value calculated in current step to memory manager
+
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+
+ attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+
+            # NOTE: no bug in context attn fwd (this note can be deleted)
+ llama2_context_attn_fwd(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.seq_length_with_past,
+ )
+
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_layer)
+ cache_v.copy_(value_layer)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim (embed_size_per_head)]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+ # second token and follows
+ attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+
+ # ==================================
+ # core attention computation is replaced by triton kernel
+ # ==================================
+ Llama2TokenAttentionForwards.token_attn(
+ query_layer,
+ cache_k,
+ cache_v,
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+ # print('after attention',torch.isnan(attn_output).any())
+
+ # =================
+ # Output:[b,sq, h]
+ # =================
+
+ output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
+ return output, kv_cache
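
The `_init_to_get_rotary` helper at the top of this file precomputes the cos/sin rotary tables once, optionally rescaled with the NTK-aware base change gated by `INFER_NTK_ALPHA`. A CPU-only sketch of the same computation, with illustrative sizes rather than values from a real ChatGLM2 config:

```python
import torch

head_dim, max_seq_len, base = 64, 2048, 10000.0  # illustrative values
ntk_alpha = 2.0  # stands in for the INFER_NTK_ALPHA environment variable

# NTK-aware base change: a larger base stretches the usable context window
base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))
max_seq_len = int(max_seq_len * ntk_alpha)

n_elem = head_dim // 2
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float32) / n_elem))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)  # one row of rotation frequencies per position
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
```
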
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 219cd1ae0d0e..64d6e947e924 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -1,18 +1,20 @@
from typing import List, Optional, Tuple
-import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+ copy_kv_cache_to_dest,
+ llama_context_attn_fwd,
+ rotary_embedding_fwd,
+ token_attention_fwd,
+)
try:
from vllm import layernorm_ops, pos_encoding_ops
+
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
@@ -27,17 +29,17 @@
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -69,8 +71,7 @@ def llama_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
-
- batch_size = input_ids.shape[0] # input_ids.shape[0]
+    batch_size = input_ids.shape[0]
infer_state = self.infer_state
@@ -99,12 +100,13 @@ def llama_model_forward(
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
if use_cache and seq_length != 1:
- # NOTE assuem prefill stage
+ # NOTE assume prefill stage
# allocate memory block
- infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
- infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
- infer_state.context_mem_index)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -127,20 +129,20 @@ def llama_model_forward(
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device)
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage:
-
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
- position_ids.view(-1).shape[0], -1)
+ position_ids.view(-1).shape[0], -1
+ )
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
- position_ids.view(-1).shape[0], -1)
+ position_ids.view(-1).shape[0], -1
+ )
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
@@ -151,12 +153,13 @@ def llama_model_forward(
# embed positions
if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length_with_past),
- dtype=torch.bool,
- device=inputs_embeds.device)
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
- attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
- past_key_values_length)
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
hidden_states = inputs_embeds
@@ -214,7 +217,6 @@ def llama_decoder_layer_forward(
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -259,7 +261,6 @@ def llama_flash_attn_kvcache_forward(
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size()
@@ -275,8 +276,8 @@ def llama_flash_attn_kvcache_forward(
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
- if infer_state.decode_layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_len # seq_len
+ if infer_state.decode_layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
@@ -297,38 +298,62 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
# first token generation
# copy key and value calculated in current step to memory manager
- _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
- infer_state.cache_manager)
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
attn_output = torch.empty_like(query_states)
- llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
- infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
else:
-
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
- _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
- infer_state.decode_mem_index, infer_state.cache_manager)
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
- token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
- infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
- infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length)
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -339,7 +364,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
def get_llama_vllm_rmsnorm_forward():
-
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py
index 48f8db62c32a..776c4e850565 100644
--- a/colossalai/inference/tensor_parallel/policies/__init__.py
+++ b/colossalai/inference/tensor_parallel/policies/__init__.py
@@ -1,4 +1,5 @@
from .bloom import BloomModelInferPolicy
+from .chatglm2 import ChatGLM2InferPolicy
from .llama import LlamaModelInferPolicy
-__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
+__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"]
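
With ChatGLM2 added, the package now exports three inference policies. A hypothetical dispatch helper, not part of this change, showing how a caller might select one by model type:

```python
from colossalai.inference.tensor_parallel.policies import (
    BloomModelInferPolicy,
    ChatGLM2InferPolicy,
    LlamaModelInferPolicy,
)

_POLICY_REGISTRY = {
    "bloom": BloomModelInferPolicy,
    "chatglm2": ChatGLM2InferPolicy,
    "llama": LlamaModelInferPolicy,
}

def get_infer_policy(model_type: str):
    try:
        return _POLICY_REGISTRY[model_type]()
    except KeyError:
        raise ValueError(f"no inference policy registered for {model_type!r}")
```
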
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
index 63791fe27284..3d6df2097000 100644
--- a/colossalai/inference/tensor_parallel/policies/bloom.py
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -3,15 +3,19 @@
import torch
from torch.nn import LayerNorm
+import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
from ..modeling.bloom import BloomInferenceForwards
try:
- from colossalai.kernel.triton.fused_layernorm import layer_norm
+ from colossalai.kernel.triton import layer_norm
+
HAS_TRITON_NORM = True
except:
- print("you should install triton from https://github.com/openai/triton")
+ print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
HAS_TRITON_NORM = False
@@ -27,40 +31,69 @@ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
class BloomModelInferPolicy(BloomForCausalLMPolicy):
-
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+
policy = super().module_policy()
+ if self.shard_config.inference_gptq:
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+            policy[BloomBlock] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attention.hidden_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.split_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.query_key_value",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 3},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.dense",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.attention_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_h_to_4h",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_4h_to_h",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
# NOTE set inference mode to shard config
self.shard_config._infer()
method_replacement = {
- 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
- 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
+ "forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
+ "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=BloomForCausalLM)
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomForCausalLM
+ )
- method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
+ method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
- method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
+ method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
- method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=BloomAttention)
+ method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomAttention
+ )
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
- method_replacement = {'forward': partial(infer_method)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LayerNorm)
+ method_replacement = {"forward": partial(infer_method)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LayerNorm
+ )
return policy
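
Every entry in this policy goes through `append_or_create_method_replacement`, whose net effect is to swap a module's `forward` for the inference variant. A toy, standalone illustration of per-instance method replacement; this is not ColossalAI's actual shardformer mechanism:

```python
import types

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        return x + 1

def inference_forward(self, x):
    # stands in for a fused, KV-cache-aware inference path
    return x * 2

m = Toy()
m.forward = types.MethodType(inference_forward, m)  # instance attribute shadows the class method
assert m(torch.tensor(1.0)).item() == 2.0
```
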
diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py
new file mode 100644
index 000000000000..cb223370a65d
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -0,0 +1,77 @@
+from functools import partial
+
+import torch
+
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+)
+# import colossalai
+from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
+
+try:
+ from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+ HAS_TRITON_RMSNORM = True
+except:
+ print("you should install triton from https://github.com/openai/triton")
+ HAS_TRITON_RMSNORM = False
+
+
+class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ self.shard_config._infer()
+
+ model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
+ method_replacement = {'forward': model_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
+
+ encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
+ method_replacement = {'forward': encoder_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=GLMTransformer)
+
+ encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
+ method_replacement = {'forward': encoder_layer_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
+
+ attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
+ method_replacement = {'forward': attn_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=SelfAttention)
+
+ # for rmsnorm and others, we need to check the shape
+ return policy
+
+ def postprocess(self):
+ _init_to_get_rotary(self.model)
+ return self.model
+
+
+class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
+ method_replacement = {'forward': partial(model_infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=ChatGLMForConditionalGeneration)
+ return policy
+
+ def postprocess(self):
+ return super().postprocess()
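
The forwards this policy installs share one control flow: a prefill (context) pass over the whole prompt, then single-token decode passes against the KV cache. A minimal sketch of that dispatch; the function name is hypothetical:

```python
def choose_stage(seq_length: int, use_cache: bool) -> str:
    # mirrors the `use_cache and seq_length != 1` test in the forwards above
    return "prefill" if use_cache and seq_length != 1 else "decode"

assert choose_stage(128, True) == "prefill"  # the whole prompt is processed in one context pass
assert choose_stage(1, True) == "decode"     # later steps feed one token against the KV cache
```
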
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index e819f2a8810c..eaaadadd1f88 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,57 +1,107 @@
from functools import partial
+
import torch
-from transformers.models.llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaModel,
- LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+from colossalai.shardformer.layer import VocabParallelEmbedding1D
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
try:
- from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+ from colossalai.kernel.triton import rmsnorm_forward
+
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
-
+
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
+
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
-
+
return _triton_rmsnorm_forward
else:
return None
-
-class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
+
+ if self.shard_config.inference_gptq:
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
+
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
- method_replacement = {'forward': partial(infer_forward)}
+ method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaDecoderLayer)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+ )
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaAttention)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaAttention
+ )
infer_forward = None
if HAS_TRITON_RMSNORM:
@@ -59,12 +109,11 @@ def module_policy(self):
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()
-
+
if infer_forward is not None:
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaRMSNorm)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+ )
return policy
-
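
The GPTQ branch above pairs column-sharded linears (`ColCaiQuantLinear`, split over output features) with row-sharded ones (`RowCaiQuantLinear`, split over input features) so each MLP needs only one reduction. A dense, single-process sketch of why that pairing is exact; sizes and weights are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, tp = 8, 2                    # hidden size and a hypothetical tensor-parallel degree
x = torch.randn(4, d)
w_up = torch.randn(4 * d, d)    # up projection, column-sharded over output rows
w_down = torch.randn(d, 4 * d)  # down projection, row-sharded over input columns

full = F.linear(F.linear(x, w_up), w_down)

chunk = (4 * d) // tp
partial = torch.zeros(4, d)
for r in range(tp):  # each iteration plays the role of one rank
    y_r = F.linear(x, w_up[r * chunk : (r + 1) * chunk])              # column-parallel: no communication
    partial += F.linear(y_r, w_down[:, r * chunk : (r + 1) * chunk])  # row-parallel: summing == all-reduce

assert torch.allclose(partial, full, atol=1e-4)
```
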
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index a1694e059fb4..aac57d34a2c1 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -1,69 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import argparse
import os
-import pprint
+import warnings
from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Union
import torch
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim.optimizer import Optimizer
-from torch.utils.data import DataLoader
+import torch.distributed as dist
-from colossalai.amp import AMP_TYPE, convert_to_amp
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
-from colossalai.legacy.builder.builder import build_gradient_handler
-from colossalai.legacy.engine import Engine
-from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
-from colossalai.legacy.engine.schedule import (
- InterleavedPipelineSchedule,
- NonPipelineSchedule,
- PipelineSchedule,
- get_tensor_shape,
-)
+from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
-from colossalai.utils.moe import sync_moe_model_param
-from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
-from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
-
-
-def get_default_parser():
- """Reads user command line and uses an argument parser to parse the input arguments.
- Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
-
- Returns:
- Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
- """
- parser = argparse.ArgumentParser()
- parser.add_argument('--config', type=str, help='path to the config file')
- parser.add_argument('--host', type=str, help='the master address for distributed training')
- parser.add_argument('--port', type=int, help='the master port for distributed training')
- parser.add_argument('--world_size', type=int, help='world size for distributed training')
- parser.add_argument('--rank', type=int, help='rank for the default process group')
- parser.add_argument('--local_rank', type=int, help='local rank on the node')
- parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
- return parser
-
-
-def launch(config: Union[str, Path, Config, Dict],
- rank: int,
- world_size: int,
- host: str,
- port: int,
- backend: str = 'nccl',
- local_rank: int = None,
- seed: int = 1024,
- verbose: bool = True):
+from colossalai.utils import set_device, set_seed
+
+
+def launch(
+ config: Union[str, Path, Config, Dict],
+ rank: int,
+ world_size: int,
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ local_rank: int = None,
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
@@ -83,48 +44,33 @@ def launch(config: Union[str, Path, Config, Dict],
Raises:
Exception: Raise exception when config type is wrong
"""
- gpc.verbose = verbose
-
- # set config
- assert isinstance(config, (Config, str, Path, dict)), \
- f'expected argument config to be Config, str or Path, but got {type(config)}'
- if not isinstance(config, Config) and isinstance(config, dict):
- config = Config(config)
- if isinstance(config, (str, Path)):
- config = Config.from_file(config)
- gpc.load_config(config)
+ if rank == 0:
+ warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group
- gpc.init_global_dist(rank, world_size, backend, host, port)
-
- # init process groups for different parallel modes from config
- gpc.init_parallel_groups()
+ init_method = f"tcp://[{host}]:{port}"
+ dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
- gpc.set_device(local_rank)
+ set_device(local_rank)
- # set the number of processes running on the same node
- gpc.detect_num_processes_on_current_node()
-
- gpc.set_seed(seed)
+ set_seed(seed)
if verbose:
logger = get_dist_logger()
- logger.info(
- f'Distributed environment is initialized, '
- f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
- f'tensor parallel size: {gpc.tensor_parallel_size}',
- ranks=[0])
+ logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
-def launch_from_slurm(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+def launch_from_slurm(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
@@ -137,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['SLURM_PROCID'])
- world_size = int(os.environ['SLURM_NPROCS'])
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NPROCS"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
)
- launch(config=config,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_openmpi(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_openmpi(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
@@ -172,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
- local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
- world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_torch(config: Union[str, Path, Config, Dict],
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_torch(
+ config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
+):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@@ -205,266 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def initialize(model: nn.Module,
- optimizer: Optimizer,
- criterion: Optional[_Loss] = None,
- train_dataloader: Optional[Iterable] = None,
- test_dataloader: Optional[Iterable] = None,
- lr_scheduler: Optional[_LRScheduler] = None,
- ophooks: Optional[List[BaseOpHook]] = None,
- verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
- """Core function to wrap the essential training components with our functionality based on the config which is
- loaded into gpc.config.
-
- Args:
- model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
- optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
- Your optimizer instance.
- criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
- train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
- test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
- lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
- verbose (bool, optional): Whether to print logs.
-
- Returns:
- Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
- A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
- where only ``engine`` could not be None.
- """
- # get logger
- logger = get_dist_logger()
- gpc.verbose = verbose
-
- # get config from gpc
- config = gpc.config
-
- # print config
- if verbose:
- logger.info(
- f"\n========== Your Config ========\n"
- f"{pprint.pformat(gpc.config)}\n"
- f"================================\n",
- ranks=[0])
-
- # cudnn
- cudnn_benchmark = config.get('cudnn_benchmark', False)
- cudnn_deterministic = config.get('cudnn_deterministic', False)
- torch.backends.cudnn.benchmark = cudnn_benchmark
- torch.backends.cudnn.deterministic = cudnn_deterministic
- if verbose:
- logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
-
- # zero
- use_zero = hasattr(gpc.config, 'zero')
- if use_zero:
- zero_cfg = gpc.config.get('zero', None)
- if zero_cfg is not None:
- cfg_ = zero_cfg.copy()
- else:
- cfg_ = {}
- optimizer_config = zero_cfg.get('optimizer_config', None)
- model_config = zero_cfg.get('model_config', None)
- model, optimizer = convert_to_zero_v2(model,
- optimizer,
- model_config=model_config,
- optimizer_config=optimizer_config)
-
- logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
- else:
- if isinstance(model, nn.Module):
- # first sync model across dp ranks
- model.to(get_current_device())
- elif isinstance(model, Callable):
- model = model().to(get_current_device())
-
- # optimizer maybe a optimizer_cls
- if isinstance(optimizer, Callable):
- optimizer = optimizer(model.parameters())
- logger.warning("Initializing an non ZeRO model with optimizer class")
-
- if not use_zero:
- if is_using_sequence():
- sync_model_param(model, ParallelMode.SEQUENCE_DP)
- elif MOE_CONTEXT.is_initialized:
- sync_moe_model_param(model)
- elif is_using_ddp():
- sync_model_param(model, ParallelMode.DATA)
- else:
- logger.warning(
- "The parameters of models is not automatically synchronized.\n"
- "Please make sure that all parameters are the same in data parallel group.",
- ranks=[0])
-
- # check amp and zero
- fp16_cfg = gpc.config.get('fp16', None)
-
- if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
- raise ConfigException(
- "It is not allowed to set fp16 and zero configuration in your config file at the same time")
-
- # clip grad norm
- clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
-
- # initialize amp
- amp_mode = None
- if fp16_cfg is not None and fp16_cfg.mode is not None:
- cfg_ = fp16_cfg.copy()
- amp_mode = cfg_.pop('mode')
- if is_using_pp():
- assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
- if amp_mode == AMP_TYPE.NAIVE:
- cfg_['clip_grad_norm'] = clip_grad_norm
- model, optimizer, criterion = convert_to_amp(model=model,
- optimizer=optimizer,
- criterion=criterion,
- mode=amp_mode,
- amp_config=cfg_)
-
- # get torch ddp config
- torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
-
- # gradient handler
- gradient_handler_cfg = gpc.config.get('gradient_handler', None)
- if gradient_handler_cfg is None:
- # if gradient handler is not specified in the configuration file,
- # check in the following order
- # 1. if optimizer is ZERO, then use zero grad handler
- # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
- # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
- if isinstance(optimizer, ShardedOptimizerV2):
- gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
- if verbose:
- logger.info(
- "Training with zero is detected, ZeROGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_ddp() and MOE_CONTEXT.is_initialized:
- gradient_handler_cfg = [dict(type='MoeGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_sequence():
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
- ranks=[0])
- elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.DATA),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
- elif is_using_ddp():
- gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected when using pipeline parallel, "
- "DataParallelGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- # add pipeline parallel gradient handler, if pipeline shared module is detected
- for param in model.parameters():
- if getattr(param, 'pipeline_shared_module_pg', None) is not None:
- if gradient_handler_cfg is None:
- gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
- else:
- gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
- if verbose:
- logger.info(
- "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- break
- else:
- if not isinstance(gradient_handler_cfg, list):
- raise ConfigException(
- f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
- )
-
- # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
- # to avoid duplicated buffer synchronization
- if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
- model.module.sync_buffer = False
-
- # initialize schedule for engine
- if is_using_pp():
- tensor_shape = get_tensor_shape()
- use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
- if gpc.is_initialized(ParallelMode.PARALLEL_1D):
- scatter_gather = True
- else:
- scatter_gather = False
- if use_interleaved:
- if isinstance(model, nn.Sequential):
- model = nn.ModuleList([model])
- schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- gpc.config.model.num_chunks,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = NonPipelineSchedule()
-
- if gradient_handler_cfg is None:
- gradient_handlers = None
- if verbose and not isinstance(model, DDP):
- logger.warning(
- "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
- "to all-reduce the gradients after a training step.",
- ranks=[0])
- else:
- gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
-
- # check if optimizer is ColossalaiOptimizer
- if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
- optimizer = ColossalaiOptimizer(optim=optimizer)
-
- # gradient accumulation
- grad_accum_size = gpc.config.get('gradient_accumulation', None)
- if grad_accum_size is not None:
- optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
- model=model,
- optimizer=optimizer,
- dataloader=train_dataloader,
- accumulate_size=grad_accum_size,
- gradient_handlers=gradient_handlers,
- lr_scheduler=lr_scheduler)
- engine = Engine(model=model,
- optimizer=optimizer,
- criterion=criterion,
- gradient_handlers=gradient_handlers,
- clip_grad_norm=clip_grad_norm,
- ophook_list=ophooks,
- schedule=schedule)
-
- return engine, train_dataloader, test_dataloader, lr_scheduler
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 1c3199fc1aff..98b21c9c02c1 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
-__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
+__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"]
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index 7b3d9435d255..58df09b853ee 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -26,11 +26,9 @@ def forward(self, *args, **kwargs):
class AMPModelMixin:
- """This mixin class defines the interface for AMP training.
- """
+ """This mixin class defines the interface for AMP training."""
def update_master_params(self):
"""
Update the master parameters for AMP training.
"""
- pass
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index bc270b1d9c89..95d11087bece 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -22,7 +22,7 @@ def parameters(self):
params = []
for group in self.param_groups:
- params += group['params']
+ params += group["params"]
return params
@property
@@ -82,12 +82,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
"""
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> Tensor:
"""
Clips gradient norm of an iterable of parameters.
@@ -113,7 +115,8 @@ def scale_loss(self, loss: Tensor):
loss (Tensor): The loss to be scaled.
"""
raise NotImplementedError(
- "The method scale_loss is only available for optimizers with mixed precision training")
+ "The method scale_loss is only available for optimizers with mixed precision training"
+ )
def unscale_grad(self):
"""
@@ -122,7 +125,8 @@ def unscale_grad(self):
Note: Only available for optimizers with mixed precision training.
"""
raise NotImplementedError(
- "The method unscale_grad is only available for optimizers with mixed precision training")
+ "The method unscale_grad is only available for optimizers with mixed precision training"
+ )
def unwrap(self):
"""
diff --git a/colossalai/interface/pretrained.py b/colossalai/interface/pretrained.py
new file mode 100644
index 000000000000..2f6bc10cd132
--- /dev/null
+++ b/colossalai/interface/pretrained.py
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from torch.nn import Module
+
+__all__ = [
+ "get_pretrained_path",
+ "set_pretrained_path",
+]
+
+
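+# Stash/read a checkpoint path on the module object itself (a "_pretrained" attribute) so later loading code can find it.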
+def get_pretrained_path(model: Module) -> Optional[str]:
+ return getattr(model, "_pretrained", None)
+
+
+def set_pretrained_path(model: Module, path: str) -> None:
+ setattr(model, "_pretrained", path)
diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
index a99cb497c3e7..8933fc0a3c2f 100644
--- a/colossalai/kernel/__init__.py
+++ b/colossalai/kernel/__init__.py
@@ -1,14 +1,7 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
- "llama_context_attn_fwd",
- "bloom_context_attn_fwd",
- "softmax",
- "copy_kv_cache_to_dest",
]
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index e0136d86e561..f8a974b5fb26 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -4,6 +4,10 @@
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [
- 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
- 'AttnMaskType'
+ "LayerNorm",
+ "MultiHeadAttention",
+ "FusedScaleMaskSoftmax",
+ "ScaledUpperTriangMaskedSoftmax",
+ "ColoAttention",
+ "AttnMaskType",
]
diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h
index 00066dc95475..a62beef91a8a 100644
--- a/colossalai/kernel/cuda_native/csrc/compat.h
+++ b/colossalai/kernel/cuda_native/csrc/compat.h
@@ -7,4 +7,4 @@
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
-#endif
\ No newline at end of file
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
new file mode 100644
index 000000000000..2b1b366b1c02
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
@@ -0,0 +1,63 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "column_remap.cuh"
+#include "util.cuh"
+
+const int SHUF_BLOCKSIZE_X = 256;
+const int SHUF_BLOCKSIZE_Y = 16;
+
+__global__ void column_remap_kernel
+(
+ const half* __restrict__ x,
+ half* __restrict__ x_new,
+ const int x_width,
+ const int x_height,
+ const uint32_t* x_map
+)
+{
+ int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
+ int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
+ if (x_column >= x_width) return;
+ //if (x_row >= x_height) return;
+
+ int x_stride = x_width;
+ int x_idx = x_row * x_stride + x_column;
+
+ int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
+ int x_idx_end = x_row_end * x_stride + x_column;
+
+ int s_column = x_map[x_column];
+ int s_idx = x_row * x_stride + s_column;
+
+ while (x_idx < x_idx_end)
+ {
+ x_new[x_idx] = x[s_idx];
+ x_idx += x_stride;
+ s_idx += x_stride;
+ }
+}
+
+// Remap columns in x to correspond to sequential group index before matmul
+//
+// perform x -> seq_x such that seq_x @ seq_w == x @ w
+
+void column_remap_cuda
+(
+ const half* x,
+ half* x_new,
+ const int x_height,
+ const int x_width,
+ const uint32_t* x_map
+)
+{
+ dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
+
+ dim3 blocks
+ (
+ (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
+ (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
+ 1
+ );
+
+ column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
new file mode 100644
index 000000000000..6571c17d6fd5
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
@@ -0,0 +1,19 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _column_remap_cuh
+#define _column_remap_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+void column_remap_cuda
+(
+ const half* x,
+ half* x_new,
+ const int x_height,
+ const int x_width,
+ const uint32_t* x_map
+);
+
+#endif
\ No newline at end of file
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
new file mode 100644
index 000000000000..c5258813e147
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
@@ -0,0 +1,58 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _cuda_compat_cuh
+#define _cuda_compat_cuh
+
+// atomicAdd for half types, to support CC < 7.x
+
+__device__ __forceinline__ void atomicAdd_half(half* address, half val)
+{
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
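+ // CAS loop: re-read the 32-bit word containing the target half, add into its 16-bit lane, and retry if another thread modified the word meanwhile.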
+ do
+ {
+ assumed = old;
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = __hadd(hsum, val);
+ hsum = __half_raw(tmpres);
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+ old = atomicCAS(address_as_ui, assumed, old);
+ }
+ while (assumed != old);
+}
+
+// atomicAdd for half2 types
+
+__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
+{
+ unsigned int* address_as_ui = (unsigned int*)address;
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+ do
+ {
+ assumed = old;
+ half2 old_val = *((half2*)&old);
+ half2 new_val = __hadd2(old_val, val);
+ old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
+ }
+ while (assumed != old);
+}
+
+//
+
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
+
+__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
+__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+#endif
+
+#endif
+#endif
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
new file mode 100644
index 000000000000..4416027c8387
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
@@ -0,0 +1,75 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#define _cuda_buffers_cu
+#include "cuda_buffers.cuh"
+
+CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
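+// One buffer set per CUDA device, created by prepare_buffers_cuda() and freed by cleanup_buffers_cuda().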
+// __constant__ half2 q4_table[16][256];
+// half2 q4_table_host[16][256];
+// bool q4_table_init = false;
+
+CudaBuffers::CudaBuffers
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+) :
+ device(_device),
+ temp_state_size(_temp_state_size),
+ temp_state(_temp_state),
+ temp_dq(_temp_dq)
+{
+ cudaSetDevice(_device);
+
+ cudaStreamCreate(&alt_stream_1);
+ cudaStreamCreate(&alt_stream_2);
+ cudaStreamCreate(&alt_stream_3);
+ cudaEventCreate(&alt_stream_1_done);
+ cudaEventCreate(&alt_stream_2_done);
+ cudaEventCreate(&alt_stream_3_done);
+}
+
+CudaBuffers::~CudaBuffers()
+{
+ cudaStreamDestroy(alt_stream_1);
+ cudaStreamDestroy(alt_stream_2);
+ cudaStreamDestroy(alt_stream_3);
+ cudaEventDestroy(alt_stream_1_done);
+ cudaEventDestroy(alt_stream_2_done);
+ cudaEventDestroy(alt_stream_3_done);
+}
+
+CudaBuffers* get_buffers(const int device_index)
+{
+ return g_buffers[device_index];
+}
+
+void prepare_buffers_cuda
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+)
+{
+ CudaBuffers* buffers = new CudaBuffers
+ (
+ _device,
+ _temp_state_size,
+ _temp_state,
+ _temp_dq
+ );
+
+ g_buffers[_device] = buffers;
+}
+
+void cleanup_buffers_cuda()
+{
+ for (int i = 0; i < CUDA_MAX_DEVICES; i++)
+ {
+ if (!g_buffers[i]) continue;
+ delete g_buffers[i];
+ g_buffers[i] = NULL;
+ }
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
new file mode 100644
index 000000000000..0bf2057c665c
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
@@ -0,0 +1,55 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _cuda_buffers_cuh
+#define _cuda_buffers_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+const int CUDA_MAX_DEVICES = 16;
+
+// #ifndef _cuda_buffers_cu
+// extern __constant__ half2 q4_table[16][256];
+// #endif
+
+class CudaBuffers
+{
+public:
+ int device;
+
+ half* temp_state; // [max_hidden_rows * intermediate_size]
+ int temp_state_size;
+ half* temp_dq; // size of largest quant tensor * 8
+
+ cudaStream_t alt_stream_1;
+ cudaStream_t alt_stream_2;
+ cudaStream_t alt_stream_3;
+ cudaEvent_t alt_stream_1_done;
+ cudaEvent_t alt_stream_2_done;
+ cudaEvent_t alt_stream_3_done;
+
+ CudaBuffers
+ (
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+ );
+ ~CudaBuffers();
+};
+
+CudaBuffers* get_buffers(const int device_index);
+
+void prepare_buffers_cuda
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+);
+
+void cleanup_buffers_cuda();
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
new file mode 100644
index 000000000000..5cd2e8553ef6
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
@@ -0,0 +1,49 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
+
+// Workaround for a bug in hipamd, backported from upstream.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+ return __half_raw{
+ static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+ return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+ static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+ hipblasOperation_t transA,
+ hipblasOperation_t transB,
+ int m,
+ int n,
+ int k,
+ const half* alpha,
+ const half* AP,
+ int lda,
+ const half* BP,
+ int ldb,
+ const half* beta,
+ half* CP,
+ int ldc) {
+ return hipblasHgemm(handle, transA, transB, m, n, k,
+ reinterpret_cast<const hipblasHalf *>(alpha),
+ reinterpret_cast<const hipblasHalf *>(AP), lda,
+ reinterpret_cast<const hipblasHalf *>(BP), ldb,
+ reinterpret_cast<const hipblasHalf *>(beta),
+ reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+
+#define rocblas_handle hipblasHandle_t
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
+#define rocblas_hgemm __compat_hipblasHgemm
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
new file mode 100644
index 000000000000..bcc0e43901de
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
@@ -0,0 +1,254 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include "util.cuh"
+#include "tuning.h"
+#include "cuda_buffers.cuh"
+#include "q4_matrix.cuh"
+#include "q4_matmul.cuh"
+#include "column_remap.cuh"
+
+// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
+// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
+// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
+
+void check_cuda(cudaError_t ret)
+{
+ switch (ret)
+ {
+ case cudaSuccess:
+ break;
+
+ case cudaUnspecified:
+ printf(" **** Unspecified error\n");
+ TORCH_CHECK(false, "CUDA error");
+ break;
+
+ default:
+ printf(" **** CUDA error\n"); \
+ printf(" **** %s\n", cudaGetErrorString(ret)); \
+ TORCH_CHECK(false, "CUDA error"); \
+ break;
+ }
+}
+
+// Some decluttering macros
+
+#define STRINGIFY_(__x) #__x
+#define STRINGIFY(__x) STRINGIFY_(__x)
+#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
+#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
+
+#define TORCH_CHECK_DEVICE_INDEX(__index) \
+do { \
+ TORCH_CHECK(__index >= 0, "no device index"); \
+ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
+} while(0)
+
+#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
+do { \
+ TORCH_CHECK_DTYPE(__w, kInt); \
+ TORCH_CHECK_DTYPE(__w_scales, kHalf); \
+ TORCH_CHECK_DTYPE(__w_zeros, kInt); \
+ TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
+ TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
+ TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
+ TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
+} while(0)
+
+int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
+{
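+ // Each uint32 in qweight packs eight 4-bit weights, so the unpacked height is w.size(0) * 8; dividing by the zero-point rows gives the group size.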
+ int groupsize = w.size(0) * 8 / w_zeros.size(0);
+ TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
+ return groupsize;
+}
+
+
+// Tuning parameters
+
+ExLlamaTuning tuningParams;
+
+void set_tuning_params
+(
+ int matmul_recons_thd,
+ bool matmul_fused_remap,
+ bool matmul_no_half2
+)
+{
+ tuningParams.matmul_recons_thd = matmul_recons_thd;
+ tuningParams.matmul_fused_remap = matmul_fused_remap;
+ tuningParams.matmul_no_half2 = matmul_no_half2;
+}
+
+
+// Release all unmanaged objects allocated by the extension
+
+void cleanup()
+{
+ cleanup_buffers_cuda();
+ g_q4_free_matrices();
+}
+
+
+// Prepare buffers for forward pass
+
+void prepare_buffers
+(
+ torch::Device device,
+ torch::Tensor temp_state,
+ torch::Tensor temp_dq
+)
+{
+ int device_index = device.index();
+ TORCH_CHECK_DEVICE_INDEX(device_index);
+ const at::cuda::OptionalCUDAGuard device_guard(device);
+
+ prepare_buffers_cuda
+ (
+ device_index,
+ // buffer size used for sanity checks
+ temp_state.numel(),
+ (half*) temp_state.data_ptr(),
+ (half*) temp_dq.data_ptr()
+ );
+}
+
+
+// Create Q4Matrix, return handle
+
+uintptr_t make_q4
+(
+ torch::Tensor qweight,
+ torch::Tensor qzeros,
+ torch::Tensor scales,
+ torch::Tensor g_idx,
+ int device
+)
+{
+ TORCH_CHECK_DTYPE(qweight, kInt);
+ TORCH_CHECK_DTYPE(qzeros, kInt);
+ TORCH_CHECK_DTYPE(scales, kHalf);
+ TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
+ TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
+ TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
+ TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
+
+ int width = qweight.size(1);
+ int height = qweight.size(0) * 8;
+ int groups = qzeros.size(0);
+
+ Q4Matrix* m = new Q4Matrix
+ (
+ height,
+ width,
+ groups,
+
+ (uint32_t*) qweight.data_ptr(),
+ (uint32_t*) qzeros.data_ptr(),
+ (half*) scales.data_ptr(),
+ g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
+
+ device
+ );
+
+ g_q4_keep_matrix(m);
+ return reinterpret_cast<uintptr_t> (m);
+}
+
+
+// Matmul half @ quant -> half
+
+void q4_matmul
+(
+ torch::Tensor x,
+ uintptr_t w,
+ torch::Tensor out
+)
+{
+ Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
+
+ TORCH_CHECK_DTYPE(x, kHalf);
+ TORCH_CHECK_DTYPE(out, kHalf);
+ TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
+ TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ int x_height = x.size(0);
+
+ if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
+ {
+ q4_matmul_cuda
+ (
+ &tuningParams,
+ (half*) x.data_ptr(),
+ x_height,
+ wm,
+ (half*) out.data_ptr()
+ );
+ }
+ else
+ {
+ q4_matmul_recons_cuda
+ (
+ &tuningParams,
+ (half*) x.data_ptr(),
+ x_height,
+ wm,
+ (half*) out.data_ptr(),
+ at::cuda::getCurrentCUDABlasHandle()
+ );
+ }
+}
+
+
+// Remap columns in half tensor
+
+void column_remap
+(
+ torch::Tensor x,
+ torch::Tensor x_new,
+ torch::Tensor x_map
+)
+{
+ TORCH_CHECK_DTYPE(x, kHalf);
+ TORCH_CHECK_DTYPE(x_new, kHalf);
+ TORCH_CHECK_DTYPE(x_map, kInt);
+ TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
+
+ int height = x.size(0);
+ int width = x.size(1);
+
+ TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ column_remap_cuda
+ (
+ (half*) x.data_ptr(),
+ (half*) x_new.data_ptr(),
+ height,
+ width,
+ (uint32_t*) x_map.data_ptr()
+ );
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
+ m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
+ m.def("cleanup", &cleanup, "cleanup");
+ m.def("make_q4", &make_q4, "make_q4");
+ m.def("q4_matmul", &q4_matmul, "q4_matmul");
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
new file mode 100644
index 000000000000..2fd5ab0b36cd
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
@@ -0,0 +1,294 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _matrix_cuh
+#define _matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+class MatrixView_half
+{
+public:
+ const half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
+ __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
+};
+
+class MatrixView_half_rw
+{
+public:
+ half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
+ __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
+ __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
+};
+
+class MatrixView_q4_row
+{
+public:
+ const uint32_t* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ int item(int row, int column) const
+ {
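+ // Each uint32 packs eight 4-bit values along the row; select this column's nibble.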
+ int shift = (column & 0x07) * 4;
+ return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
+ }
+};
+
+class MatrixView_q4_column
+{
+public:
+ const uint32_t* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ int item(int row, int column) const
+ {
+ int shift = (row & 0x07) * 4;
+ return (data[row / 8 * width + column] >> shift) & 0x0f;
+ }
+
+ __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
+ __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
+};
+
+// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
+
+// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
+
+__device__ __forceinline__ half2 dot_product_8
+(
+ const half2 acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half2 v_scale_2,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count
+)
+{
+ const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half2 result = acc;
+
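+ // Each loop iteration consumes one uint32, i.e. eight 4-bit weights; v_zero was passed in with a +1 offset (see parameter comment).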
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half2 v_01 = __halves2half2(v_0, v_1);
+ half2 v_23 = __halves2half2(v_2, v_3);
+ half2 v_45 = __halves2half2(v_4, v_5);
+ half2 v_67 = __halves2half2(v_6, v_7);
+
+// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
+// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
+// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
+// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
+
+ half2 tmp = __hmul2(*h_ptr++, v_01);
+ tmp = __hfma2(*h_ptr++, v_23, tmp);
+ tmp = __hfma2(*h_ptr++, v_45, tmp);
+ tmp = __hfma2(*h_ptr++, v_67, tmp);
+ result = __hfma2(v_scale_2, tmp, result);
+ }
+
+ return result;
+}
+
+__device__ __forceinline__ half dot_product_8_h
+(
+ const half acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half v_scale,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, h_column);
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half tmp = __hmul(*h_ptr++, v_0);
+ tmp = __hfma(*h_ptr++, v_1, tmp);
+ tmp = __hfma(*h_ptr++, v_2, tmp);
+ tmp = __hfma(*h_ptr++, v_3, tmp);
+ tmp = __hfma(*h_ptr++, v_4, tmp);
+ tmp = __hfma(*h_ptr++, v_5, tmp);
+ tmp = __hfma(*h_ptr++, v_6, tmp);
+ tmp = __hfma(*h_ptr++, v_7, tmp);
+ result = __hfma(v_scale, tmp, result);
+ }
+
+ return result;
+}
+
+// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
+
+__device__ __forceinline__ half2 dot_product_8_x_map
+(
+ const half2 acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half2 v_scale_2,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count,
+ const uint32_t* x_map
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, 0);
+ const uint32_t* x_map_ptr = x_map + h_column;
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half2 result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half2 v_01 = __halves2half2(v_0, v_1);
+ half2 v_23 = __halves2half2(v_2, v_3);
+ half2 v_45 = __halves2half2(v_4, v_5);
+ half2 v_67 = __halves2half2(v_6, v_7);
+
+ half h_0 = h_ptr[*x_map_ptr++];
+ half h_1 = h_ptr[*x_map_ptr++];
+ half h_2 = h_ptr[*x_map_ptr++];
+ half h_3 = h_ptr[*x_map_ptr++];
+ half h_4 = h_ptr[*x_map_ptr++];
+ half h_5 = h_ptr[*x_map_ptr++];
+ half h_6 = h_ptr[*x_map_ptr++];
+ half h_7 = h_ptr[*x_map_ptr++];
+
+ half2 h_01 = __halves2half2(h_0, h_1);
+ half2 h_23 = __halves2half2(h_2, h_3);
+ half2 h_45 = __halves2half2(h_4, h_5);
+ half2 h_67 = __halves2half2(h_6, h_7);
+
+ half2 tmp = __hmul2(h_01, v_01);
+ tmp = __hfma2(h_23, v_23, tmp);
+ tmp = __hfma2(h_45, v_45, tmp);
+ tmp = __hfma2(h_67, v_67, tmp);
+ result = __hfma2(v_scale_2, tmp, result);
+ }
+
+ return result;
+}
+
+__device__ __forceinline__ half dot_product_8_x_map_h
+(
+ const half acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half v_scale,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count,
+ const uint32_t* x_map
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, 0);
+ const uint32_t* x_map_ptr = x_map + h_column;
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
+ result = __hfma(v_scale, tmp, result);
+ }
+
+ return result;
+}
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
new file mode 100644
index 000000000000..f47daeb0e877
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
@@ -0,0 +1,260 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "q4_matmul.cuh"
+#include "column_remap.cuh"
+#include "util.cuh"
+#include "matrix.cuh"
+#include "cu_compat.cuh"
+#include "cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "hip_compat.cuh"
+#endif
+
+const int THREADS_X = 32; // Block size and thread count along columns in w and out
+const int THREADS_Y = 1; // Block size and thread count along rows in x and out
+
+typedef void (*fp_q4_matmul_kernel)
+(
+ const half*,
+ const uint32_t*,
+ half*,
+ const half*,
+ const uint32_t*,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const uint32_t*,
+ bool
+);
+
+template<bool use_groupsize, bool use_half2, bool use_x_map>
+__global__ void q4_matmul_kernel
+(
+ const half* __restrict__ x,
+ const uint32_t* __restrict__ w,
+ half* __restrict__ out,
+ const half* __restrict__ w_scales,
+ const uint32_t* __restrict__ w_zeros,
+ const int height,
+ const int dim,
+ const int width,
+ const int groupsize,
+ const int block_size_z,
+ const uint32_t* __restrict__ x_map,
+ bool no_zero
+)
+{
+ // Start of block
+
+ int x_column = block_size_z * blockIdx.z;
+ int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
+
+ int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+ int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
+
+ int iterations = (x_column_end - x_column) / 8;
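+ // Each iteration consumes one packed uint32, i.e. eight elements of the x row.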
+
+ // Views
+
+ MatrixView_half x_(x, height, dim);
+ MatrixView_half w_scales_(w_scales, dim / groupsize, width);
+ MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
+ MatrixView_q4_column w_(w, dim, width);
+ MatrixView_half_rw out_(out, height, width);
+
+ // Zero output
+
+ if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
+ {
+ *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
+ __syncthreads();
+ }
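+ // Partial sums from the other blockIdx.z slices are accumulated below via atomicAdd, so the output must be zeroed first unless no_zero is set.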
+
+ // Loop over part of x row (and w column)
+
+ half2 acc = {};
+ half acc_h = {};
+
+ if constexpr (use_groupsize)
+ {
+ // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
+ // could be slightly faster
+
+ for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
+ {
+ if constexpr (use_half2)
+ {
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+ }
+ else
+ {
+ half w_scale = w_scales_.item(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+ }
+ }
+ }
+ else
+ {
+ // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+
+ for (int k = x_column; k < x_column + iterations * 8; k += 8)
+ {
+ if constexpr (use_half2)
+ {
+ int group = k / groupsize;
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+ }
+ else
+ {
+ int group = k / groupsize;
+ half w_scale = w_scales_.item(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+ }
+ }
+ }
+
+ // Add to block result
+
+ if constexpr (use_half2)
+ {
+ half result = __hadd(__low2half(acc), __high2half(acc));
+ atomicAdd(out_.item_ptr(x_row, w_column), result);
+ }
+ else
+ {
+ atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
+ }
+}
+
+fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
+{
+ // <bool use_groupsize, bool use_half2, bool use_x_map>
+ if (tuningParams->matmul_no_half2) {
+ if (block_size_z % groupsize == 0) {
+ if (x_map) return q4_matmul_kernel<true, false, true>;
+ else return q4_matmul_kernel<true, false, false>;
+ } else {
+ if (x_map) return q4_matmul_kernel<false, false, true>;
+ else return q4_matmul_kernel<false, false, false>;
+ }
+ } else {
+ if (block_size_z % groupsize == 0)
+ {
+ if (x_map) return q4_matmul_kernel<true, true, true>;
+ else return q4_matmul_kernel<true, true, false>;
+ } else {
+ if (x_map) return q4_matmul_kernel<false, true, true>;
+ else return q4_matmul_kernel<false, true, false>;
+ }
+ }
+};
+
+// Compute y = x @ w
+
+void q4_matmul_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ const Q4Matrix* w,
+ half* out,
+ bool no_zero,
+ cudaStream_t alt_stream
+)
+{
+ int height = x_height;
+ int dim = w->height;
+ int width = w->width;
+
+ cudaSetDevice(w->device);
+
+ uint32_t* x_map = w->cuda_x_map;
+ const half* x_mapped = x;
+ if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
+ {
+ CudaBuffers* buffers = get_buffers(w->device);
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
+ x_mapped = buffers->temp_state;
+ x_map = NULL;
+ }
+
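+ // Heuristic split of the reduction dimension across blockIdx.z, tuned per the common Llama hidden sizes noted below.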
+ int block_size_z;
+ if (w->width == 4096) block_size_z = 384; // 7B
+ else if (w->width == 11008) block_size_z = 256;
+ else if (w->width == 5120) block_size_z = 384; // 13B
+ else if (w->width == 13824) block_size_z = 256;
+ else if (w->width == 6656) block_size_z = 256; // 33B
+ else if (w->width == 17920) block_size_z = 128;
+ else block_size_z = 256;
+
+ //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
+
+ dim3 threads(THREADS_X, THREADS_Y, 1);
+
+ dim3 blocks
+ (
+ (width + threads.x - 1) / threads.x,
+ (height + threads.y - 1) / threads.y,
+ (dim + block_size_z - 1) / block_size_z
+ );
+
+ fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
+
+ kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+}
+
+void q4_matmul_recons_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ Q4Matrix* w,
+ half* out,
+ const cublasHandle_t handle,
+ bool no_zero
+)
+{
+ int height = x_height;
+ int dim = w->height;
+ int width = w->width;
+
+ cudaSetDevice(w->device);
+ CudaBuffers* buffers = get_buffers(w->device);
+
+ const half* x_mapped = x;
+ if (w->cuda_x_map)
+ {
+ TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
+ x_mapped = buffers->temp_state;
+ }
+
+ w->reconstruct(buffers->temp_dq);
+
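+ // Reconstruction path: dequantize the full matrix into temp_dq and run a dense GEMM, which wins for taller inputs (see matmul_recons_thd).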
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+ const float alpha = 1.0f;
+ const float beta = no_zero ? 1.0f : 0.0f;
+ cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
+ x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
+#else
+ const half alpha = __float2half(1.0f);
+ const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
+ cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
+#endif
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
new file mode 100644
index 000000000000..09f3e1a63362
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
@@ -0,0 +1,43 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matmul_cuh
+#define _q4_matmul_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "q4_matrix.cuh"
+#include "tuning.h"
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+#define rocblas_handle hipblasHandle_t
+#endif
+
+void q4_matmul_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ const Q4Matrix* w,
+ half* out,
+ bool no_zero = false,
+ cudaStream_t alt_stream = NULL
+);
+
+void q4_matmul_recons_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ Q4Matrix* w,
+ half* out,
+ const cublasHandle_t handle,
+ bool no_zero = false
+);
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
new file mode 100644
index 000000000000..9c61143f565e
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
@@ -0,0 +1,225 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "q4_matrix.cuh"
+#include <vector>
+#include "util.cuh"
+#include "matrix.cuh"
+
+using namespace std;
+
+const int UNSHUF_BLOCKSIZE_X = 64;
+
+const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
+const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
+
+vector<Q4Matrix*> g_q4_matrices;
+
+void g_q4_keep_matrix(Q4Matrix* m)
+{
+ g_q4_matrices.push_back(m);
+}
+
+void g_q4_free_matrices()
+{
+ for (const auto& m : g_q4_matrices) delete m;
+ g_q4_matrices.clear();
+}
+
+Q4Matrix::Q4Matrix
+(
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _qweight,
+ uint32_t* _qzeros,
+ half* _scales,
+ uint32_t* _g_idx,
+
+ const int _device
+) :
+ height(_height),
+ width(_width),
+ groups(_groups),
+ device(_device)
+{
+ cudaSetDevice(device);
+
+ cuda_qweight = _qweight;
+ cuda_qzeros = _qzeros;
+ cuda_scales = _scales;
+
+ groupsize = height / groups;
+
+ if (_g_idx) make_sequential(_g_idx);
+}
+
+Q4Matrix::~Q4Matrix()
+{
+}
+
+// Make sequential
+
+__global__ void make_sequential_kernel
+(
+ const uint32_t* __restrict__ w,
+ uint32_t* __restrict__ w_new,
+ const uint32_t* __restrict__ x_map,
+ const int w_height,
+ const int w_width
+)
+{
+ const uint64_t* w2 = (uint64_t*) w;
+ uint64_t* w_new2 = (uint64_t*) w_new;
+ int w2_stride = w_width >> 1;
+
+ int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
+ if (w2_column >= w2_stride) return;
+
+ int w_new2_row = blockIdx.y;
+
+ int x_map_idx = w_new2_row << 3;
+
+ uint64_t dst = 0;
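+ // Gather the nibbles of eight source rows into one 64-bit destination word (two packed uint32 columns at a time).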
+
+ #pragma unroll
+ for (int i = 0; i < 8; i++)
+ {
+ int source_row = x_map[x_map_idx++];
+
+ int w2_row = source_row >> 3;
+ int w2_subrow = source_row & 0x07;
+ int w2_row_shift = w2_subrow << 2;
+ int wnew2_row_shift = i << 2;
+
+ uint64_t src = w2[w2_row * w2_stride + w2_column];
+ src >>= w2_row_shift;
+ src &= 0x0000000f0000000f;
+ src <<= wnew2_row_shift;
+ dst |= src;
+ }
+
+ w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
+
+void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
+{
+ uint32_t* cuda_new_qweight = NULL;
+ cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+ cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
+
+ uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
+ uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
+ uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
+
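+ // Counting sort on g_idx: histogram the groups, turn counts into offsets, then scatter rows to build the permutation and its inverse.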
+ // Group histogram
+
+ for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
+
+ // Group map
+
+ for (int i = 0, acc = 0; i < groups; i++)
+ {
+ short tmp = cpu_g_idx_map[i];
+ cpu_g_idx_map[i] = acc;
+ acc += tmp;
+ }
+
+ // X map (inverse)
+
+ for (int row = 0; row < height; row++)
+ {
+ uint32_t target_group = cpu_g_idx[row];
+ uint32_t target_row = cpu_g_idx_map[target_group];
+ cpu_g_idx_map[target_group]++;
+ cpu_x_map_inv[row] = target_row;
+ }
+
+ // X map
+
+ for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
+
+ // Move to CUDA
+
+ cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
+
+ // Rearrange rows in w
+
+ dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
+ dim3 blocks
+ (
+ (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
+ height / 8,
+ 1
+ );
+
+ make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+
+ // Replace qweights
+
+ cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
+
+ // Cleanup
+
+ cudaDeviceSynchronize();
+ cudaFree(cuda_new_qweight);
+ free(cpu_g_idx_map);
+ free(cpu_x_map);
+ free(cpu_x_map_inv);
+}
+
+__global__ void reconstruct_kernel
+(
+ const uint32_t* __restrict__ w,
+ half* __restrict__ out, // (y)
+ const half* __restrict__ w_scales,
+ const uint32_t* __restrict__ w_zeros,
+ const int height,
+ const int width,
+ const int groupsize
+)
+{
+ // Start of block
+
+ int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
+ int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
+ if (column >= width) return;
+
+ // Views
+
+ MatrixView_q4_column w_(w, height, width);
+ MatrixView_half_rw out_(out, height, width);
+ MatrixView_half w_scales_(w_scales, height / groupsize, width);
+ MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
+
+ // Groupsize version
+
+ int group = row / groupsize;
+
+ half w_scale = w_scales_.item(group, column);
+ uint32_t w_zero = w_zeros_.item(group, column) + 1;
+
+ uint32_t w_read = w_.item_uint32_t(row, column);
+ half* out_ptr = out_.item_ptr(row, column);
+
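+ // Dequantize the eight packed 4-bit weights in w_read: (q - zero) * scale, one value per output row.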
+ #pragma unroll
+ for (int s = 0; s < 32; s += 4)
+ {
+ half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
+ *out_ptr = w_item; out_ptr += out_.width;
+ }
+}
+
+void Q4Matrix::reconstruct(half* out)
+{
+ dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
+
+ dim3 blocks
+ (
+ (width + threads.x - 1) / threads.x,
+ (height / 8 + threads.y - 1) / threads.y,
+ 1
+ );
+
+ reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
new file mode 100644
index 000000000000..50cb72a41518
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
@@ -0,0 +1,53 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matrix_cuh
+#define _q4_matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+class Q4Matrix
+{
+public:
+
+ int device;
+
+ int height;
+ int width;
+ int groups;
+ int groupsize;
+
+ uint32_t* cuda_qweight = NULL;
+ uint32_t* cuda_qzeros = NULL;
+ half* cuda_scales = NULL;
+ uint32_t* cuda_x_map = NULL;
+
+ Q4Matrix
+ (
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _qweight,
+ uint32_t* _qzeros,
+ half* _scales,
+ uint32_t* _g_idx,
+
+ const int _device
+ );
+
+ ~Q4Matrix();
+
+ void reconstruct(half* out);
+
+private:
+
+ void make_sequential(const uint32_t* cpu_g_idx);
+
+};
+
+void g_q4_keep_matrix(Q4Matrix* m);
+void g_q4_free_matrices();
+
+#endif
\ No newline at end of file
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
new file mode 100644
index 000000000000..770ca46aa7c8
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
@@ -0,0 +1,13 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _tuning_h
+#define _tuning_h
+
+struct ExLlamaTuning
+{
+ int matmul_recons_thd;
+ bool matmul_fused_remap;
+ bool matmul_no_half2;
+};
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
new file mode 100644
index 000000000000..7b397573214b
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
@@ -0,0 +1,33 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _util_cuh
+#define _util_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
+#define cudaUnspecified cudaErrorApiFailureBase
+#endif
+
+// React to failure on return code != cudaSuccess
+
+#define _cuda_check(fn) \
+do { \
+ {_cuda_err = fn;} \
+ if (_cuda_err != cudaSuccess) goto _cuda_fail; \
+} while(false)
+
+// React to failure on return code == 0
+
+#define _alloc_check(fn) \
+do { \
+ if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
+ else _cuda_err = cudaSuccess; \
+} while(false)
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
index 26efa2ad6f31..9a6a8ebc3983 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
@@ -1,7 +1,6 @@
#include
#include
-
#include "cuda_util.h"
/* GPU function guard */
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
index a39a6dae0f7f..ce0b017f12e1 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -1,1002 +1,1002 @@
-#include <chrono>
-#include <ctime>
-
-#include "kernels.h"
-
-#include <cooperative_groups.h>
-
-
-namespace cg = cooperative_groups;
-
-curandStatePhilox4_32_10_t *curandstate;
-
-/**
- * @brief element-wise activation function on device, like Relu, Gelu
- *
- * @tparam enum class ActivationType, kRelu, kGelu
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape and type with input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_kernel(T x);
-
-template <>
-__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
- float cdf =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
- return x * cdf;
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
- __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
- float2 tmp_pow = __half22float2(val_pow3);
- float2 tmp = __half22float2(val);
-
- tmp.x =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
- tmp.y =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
- return __hmul2(val, __float22half2_rn(tmp));
-}
-
-template <>
-__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
- return fmaxf(x, 0);
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
- return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
- fmaxf(0.f, __half2float(x.y)));
-}
-
-/**
- * @brief element-wise activation backward function on device
- *
- * @tparam enum class ActivationType
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape of input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
- float x) {
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * (dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
- __half grad, __half x_half) {
- float x = __half2float(x_half);
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * __float2half(dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
- float x) {
- return x > 0.f ? grad : 0.f;
-}
-
-template <>
-__device__ __half
-activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
- const __half half_zero = __float2half(0.f);
- return x > half_zero ? grad : half_zero;
-}
-
-template <>
-__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
- __half2 grad2, __half2 x_half2) {
- const __half half_zero = __float2half(0.f);
- return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
- x_half2.y > half_zero ? grad2.y : half_zero);
-}
-
-/**
- * @brief init curand states in global memory
- *
- * @thread grid_dim * block*dim to suuport any size of states
- * @param state persistant curand states
- * @param seed seed to init states
- * @return void
- */
-__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
- int seed) {
- /* Each thread gets same seed, a different sequence
- number, no offset */
- int id = threadIdx.x + blockIdx.x * blockDim.x;
- curand_init(seed, id, 0, &state[id]);
-}
-
-void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
- cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
- int grid_dim = total_count >> 9;
- curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
- curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
-}
-
-/**
- * @brief element-wise dropout, store dropped position in mask, it's not
- * in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out any size of float and __half
- * @param in same with out
- * @param mask uint8 type, same size with out
- * @param seed seed to curand
- * @return void
- */
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- float *__restrict__ out,
- const float *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
-  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
-
- float4 input4 = data4[i];
- float4 res4;
- res4.x = input4.x * scale * m[0];
- res4.y = input4.y * scale * m[1];
- res4.z = input4.z * scale * m[2];
- res4.w = input4.w * scale * m[3];
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- __half *__restrict__ out,
- const __half *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
-  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
-  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
-  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
-  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-/**
- * @brief element-wise dropout backward using the saved dropout mask; not
- * in-place
- *
- * @thread
- * gridDim.x = total_count / 4096 (float) or total_count / 8192 (__half)
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param in any size of float and __half
- * @param mask uint8 type, same size with in
- * @return void
- */
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- float *out, const float *in,
- const uint8_t *__restrict__ mask) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- uint8_t m[4];
-
-  float4 *out4 = reinterpret_cast<float4 *>(out);
-  const float4 *in4 = reinterpret_cast<const float4 *>(in);
-  const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
-
-  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- m4[0] = mask4[i];
-
- float4 input4 = in4[i];
- float4 res4;
-  res4.x = input4.x * scale * static_cast<float>(m[0]);
-  res4.y = input4.y * scale * static_cast<float>(m[1]);
-  res4.z = input4.z * scale * static_cast<float>(m[2]);
-  res4.w = input4.w * scale * static_cast<float>(m[3]);
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- __half *out, const __half *in,
- const uint8_t *__restrict__ mask) {
- const __half scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
-  float4 *out4 = reinterpret_cast<float4 *>(out);
-  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
-  const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
-
- uint8_t m[8];
-  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- m8[0] = mask8[i];
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- out4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
-                              int total_count, float ratio, cudaStream_t stream,
-                              bool backward) {
- int grid_dim = total_count >> 12;
- if (!backward) {
-    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
-        total_count, ratio, out, vals, mask,
-        std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
-    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
-                                                             out, vals, mask);
- }
-}
-
-template <>
-void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
- int total_count, float ratio,
- cudaStream_t stream, bool backward) {
- int grid_dim = total_count >> 13;
- if (!backward) {
-    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
-        total_count, ratio, out, vals, mask,
-        std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
-    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
-                                                             out, vals, mask);
- }
-}
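-// Illustrative host-side call (hypothetical buffers, not part of this file):
-//   launch_ls_dropout<float>(d_out, d_in, d_mask, batch * seq_len * hidden,
-//                            0.1f, stream, /*backward=*/false);
-// total_count should be a multiple of 4 (float) or 8 (__half) to match the
-// vectorized float4 accesses above.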
-
-/**
- * @brief fused bias, dropout, and residual at the end of Attention and FFN;
- * stores dropped positions in mask, not in-place
- *
- * @thread
- * gridDim.x = total_count / 4096 (float) or total_count / 8192 (__half)
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param residual [batch_size, seq_len, hidden_size], float and __half
- * @param seed seed to curand
- * @param hidden_size hidden size
- * @return void
- */
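-// Fused epilogue: out = dropout(in + bias) + residual. The [hidden_size] bias
-// is broadcast along rows; bias_i = i % (hidden_size >> 2) picks the matching
-// float4 chunk of the bias vector.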
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const float *__restrict__ residual,
- const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
-  float4 *out4 = reinterpret_cast<float4 *>(out);
-  const float4 *data4 = reinterpret_cast<const float4 *>(in);
-  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
-  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
-  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
-  m[0] = static_cast<uint8_t>(rand.x > ratio);
-  m[1] = static_cast<uint8_t>(rand.y > ratio);
-  m[2] = static_cast<uint8_t>(rand.z > ratio);
-  m[3] = static_cast<uint8_t>(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
-  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 output4;
-
- output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
- output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
- output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
- output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
-
- out4[i] = output4;
-}
-
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const __half *__restrict__ residual,
- const int seed, const int hidden_size) {
- const __half scale = 1. / (1. - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
-  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
-  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
-  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
-  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
-  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
-  m[0] = static_cast<uint8_t>(rand.x > ratio);
-  m[1] = static_cast<uint8_t>(rand.y > ratio);
-  m[2] = static_cast<uint8_t>(rand.z > ratio);
-  m[3] = static_cast<uint8_t>(rand.w > ratio);
-  rand = curand_uniform4(&state);
-  m[4] = static_cast<uint8_t>(rand.x > ratio);
-  m[5] = static_cast<uint8_t>(rand.y > ratio);
-  m[6] = static_cast<uint8_t>(rand.z > ratio);
-  m[7] = static_cast<uint8_t>(rand.w > ratio);
-  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = m8[0];
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
-  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
-  const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] =
- __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
- out_half2[1] =
- __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
- out_half2[2] =
- __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
- out_half2[3] =
- __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
-                                       uint8_t *mask, const float *bias,
-                                       const float *residual, int total_count,
-                                       int dim, float ratio,
-                                       cudaStream_t stream) {
- int grid_dim = total_count >> 12;
-  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
-      total_count, ratio, out, vals, mask, bias, residual,
-      std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
- uint8_t *mask, const __half *bias,
- const __half *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 13;
-  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
-      total_count, ratio, out, vals, mask, bias, residual,
-      std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias and dropout backward at the end of Attention and FFN
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
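-// Strategy: each thread accumulates dx = dy * mask / (1 - ratio) down its
-// column while writing in_grad, parks the partial sum in the padded shared
-// tile (129 columns to avoid bank conflicts), and warp shuffles finish the
-// per-column reduction into bias_grad.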
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, float *__restrict__ in_grad,
- float *__restrict__ bias_grad, const float *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
-  // every block generates 8 bias results
- __shared__ float tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
-  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- float val = out_grad[idx];
-    val *= scale * static_cast<float>(mask[idx]);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- float sum = 0;
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, __half *__restrict__ in_grad,
- __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
- __shared__ __half2 tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
-  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
-  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
- __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- __half2 local_sum = __float2half2_rn(0.f);
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- __half2 val = out_grad2[idx];
- __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
- val *= scale * m2;
- local_sum += val;
- in_grad2[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- __half2 sum = __float2half2_rn(0.f);
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad2[pos] = tile[0][threadIdx.x];
- }
-}
-
-template <typename T>
-void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
-  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template <>
-void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
- const __half *out_grad, const uint8_t *mask,
- int row_size, int dim, float ratio,
- cudaStream_t stream) {
- dim >>= 1;
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
-  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
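-// The __half specialization halves dim first because its kernel indexes
-// __half2 pairs, so each grid column covers two scalar hidden units.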
-
-template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
- const float *out_grad,
- const uint8_t *mask, int row_size,
- int dim, float ratio,
- cudaStream_t stream);
-
-/**
- * @brief fused bias, activation, and dropout at the end of the first FFN
- * layer
- *
- * @thread
- * gridDim.x = total_count / 1024 (float) or total_count / 2048 (__half)
- * blockDim.x = 256
- *
- * @tparam act_type activation function, like kRelu, kGelu
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param seed seed to curand
- * @param hidden_size
- * @return void
- */
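-// Fused first-FFN epilogue: out = dropout(act(in + bias)), with act resolved
-// at compile time to the GeLU or ReLU device function via act_type.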
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
-  float4 *out4 = reinterpret_cast<float4 *>(out);
-  const float4 *data4 = reinterpret_cast<const float4 *>(in);
-  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
-  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
-  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 output4;
-
-  output4.x =
-      activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
-  output4.y =
-      activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
-  output4.z =
-      activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
-  output4.w =
-      activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
-
- out4[i] = output4;
-}
-
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
-  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
-  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
-  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
-  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
-  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
-  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
-
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
-  out_half2[0] = __hmul2(
-      activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
-      scale_mask_1);
-  out_half2[1] = __hmul2(
-      activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
-      scale_mask_2);
-  out_half2[2] = __hmul2(
-      activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
-      scale_mask_3);
-  out_half2[3] = __hmul2(
-      activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
-      scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
-  ls_dropout_act_bias_kernel<ActivationType::kGelu>
-      <<<grid_dim + 1, 256, 0, stream>>>(
-          total_count, ratio, out, vals, mask, bias,
-          std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
-  ls_dropout_act_bias_kernel<ActivationType::kGelu>
-      <<<grid_dim + 1, 256, 0, stream>>>(
-          total_count, ratio, out, vals, mask, bias,
-          std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
-  ls_dropout_act_bias_kernel<ActivationType::kRelu>
-      <<<grid_dim + 1, 256, 0, stream>>>(
-          total_count, ratio, out, vals, mask, bias,
-          std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
-  ls_dropout_act_bias_kernel<ActivationType::kRelu>
-      <<<grid_dim + 1, 256, 0, stream>>>(
-          total_count, ratio, out, vals, mask, bias,
-          std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias, activation, and dropout backward
- *
- * @thread
- * gridDim.x = hidden_size / WARP_SIZE
- * blockDim.x = WARP_SIZE
- * blockDim.y = WARP_SIZE
- *
- * @tparam act_type kRelu, kGelu
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
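-// Computes in_grad = act'(in + bias) applied to the mask-scaled out_grad, and
-// reduces the same values over rows into bias_grad via a WARP_SIZE x
-// WARP_SIZE shared-memory transpose plus warp shuffles.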
-template <ActivationType act_type, typename T>
-__global__ void ls_dropout_act_bias_bwd_kernel(
- const int row_size, const float ratio, T *in_grad,
- T *__restrict__ bias_grad, const T *__restrict__ input,
- const T *__restrict__ bias, const T *out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
-
- cg::thread_block b = cg::this_thread_block();
-  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
- int stride = hidden_size * WARP_SIZE;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- if (col_idx < hidden_size) {
- for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
- float val = out_grad[idx];
- float in = input[idx];
- float b = bias[idx % hidden_size];
-      val = activation_bwd_kernel<act_type, float>(
-          val * scale * static_cast<float>(mask[idx]), in + b);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
- float sum = tile[threadIdx.y][threadIdx.x];
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
- __syncthreads();
-
- if (threadIdx.y == 0) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-// @brief fused bias, activation, and dropout backward
-// It is deprecated for precision reasons; kept for future optimization.
-//
-// template <ActivationType act_type>
-// __global__ void ls_dropout_act_bias_bwd_kernel(
-// const int row_size, const float ratio, __half * in_grad,
-// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
-// __half *__restrict__ bias, const __half * out_grad, const uint8_t
-// *__restrict__ mask, const int hidden_size) {
-// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
-// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
-
-// cg::thread_block b = cg::this_thread_block();
-// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
-// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
-// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
-// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
-// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
-
-// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
-// int stride = hidden_size * WARP_SIZE;
-// __half2 local_sum = __float2half2_rn(0.f);
-
-// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
-// if (col_idx < hidden_size) {
-// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
-// __half2 val = out_grad2[idx];
-// __half2 in2 = input2[idx];
-//       __half2 b2 = bias2[idx % hidden_size];
-//       __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
-//       val = activation_bwd_kernel<act_type, __half2>(val * scale * m2,
-//                                                      in2 + b2);
-// local_sum += val;
-// in_grad2[idx] = val;
-// idx += stride;
-// }
-// }
-
-// tile[threadIdx.x][threadIdx.y] = local_sum;
-// __syncthreads();
-// __half2 sum = tile[threadIdx.y][threadIdx.x];
-// __syncthreads();
-
-// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
-// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
-// __syncthreads();
-
-// if (threadIdx.y == 0) {
-// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-// bias_grad2[pos] = tile[0][threadIdx.x];
-// }
-// }
-
-template <ActivationType act_type, typename T>
-void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
- const T *bias, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
- dim3 block_dim(WARP_SIZE, WARP_SIZE);
-  ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
-}
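-// Illustrative instantiation (hypothetical buffers): GeLU backward for an FFN
-// whose intermediate width is 4 * hidden might look like
-//   launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
-//       d_in_grad, d_bias_grad, d_in, d_bias, d_out_grad, d_mask,
-//       batch * seq_len, 4 * hidden, 0.1f, stream);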
-
-// template <>
-// void launch_ls_dropout_act_bias_bwd(
-// __half *in_grad, __half *bias_grad,const __half *input, const __half
-// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
-// dim, float ratio, cudaStream_t stream) {
-// dim >>= 1;
-// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
-// dim3 block_dim(WARP_SIZE, WARP_SIZE);
-// ls_dropout_act_bias_bwd_kernel
-//       <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
-// bias_grad,
-// input, bias,out_grad, mask, dim);
-// }
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
+#include <chrono>
+#include <ctime>
+
+#include "kernels.h"
+
+#include <cooperative_groups.h>
+
+
+namespace cg = cooperative_groups;
+
+curandStatePhilox4_32_10_t *curandstate;
+
+/**
+ * @brief element-wise activation function on device, like Relu, Gelu
+ *
+ * @tparam enum class ActivationType, kRelu, kGelu
+ * @tparam input type
+ * @param any shape of float and __half2
+ * @return same shape and type as the input
+ */
+template <ActivationType, typename T>
+__forceinline__ __device__ T activation_kernel(T x);
+
+template <>
+__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
+ float cdf =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
+ return x * cdf;
+}
+
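+// The same tanh approximation as the float path, evaluated in float2 for
+// accuracy and then rounded back to both lanes of the __half2 result.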
+template <>
+__device__ __half2
+activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
+ __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
+ float2 tmp_pow = __half22float2(val_pow3);
+ float2 tmp = __half22float2(val);
+
+ tmp.x =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
+ tmp.y =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
+ return __hmul2(val, __float22half2_rn(tmp));
+}
+
+template <>
+__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
+ return fmaxf(x, 0);
+}
+
+template <>
+__device__ __half2
+activation_kernel<ActivationType::kRelu, __half2>