Commit 6fde13b

Author: Yifu Wang (committed)
Update on "Add the option to turn on async-TP"
This PR adds the option to turn on async-TP (`--experimental.enable_async_tensor_parallel`). The feature is implemented as compiler passes over the relevant patterns, so the option currently takes effect only when compile is enabled.

Some trace samples from llama3_70b with TP degree = 8:

**all-gather -> qkv projection**

Baseline:
<img width="420" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/df6980c3-4a2f-4455-bdd3-9079b538123f">

With async-TP:
<img width="513" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/635c3dee-660d-4452-809b-32620343080a">

**ffn -> reduce-scatter**

Baseline:
<img width="537" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/6b045c84-48df-4798-a786-4f57e3f4345a">

With async-TP:
<img width="451" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/63f13859-97f7-48ea-aef6-4e8861b207ac">

**all-gather -> ffn**

Baseline:
<img width="494" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/b1636055-9b5b-43b1-b98e-b91f06af995e">

With async-TP:
<img width="536" alt="image" src="https://github.com/pytorch/torchtitan/assets/4156752/3edaedf4-3780-423d-ba86-5aa1cc5e69df">

[ghstack-poisoned]
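For readers who want a rough picture of what the option controls, here is a minimal, hypothetical sketch of how an async-TP toggle is typically wired up on top of torch.compile. The helper name `maybe_enable_async_tp`, the Inductor knob `_micro_pipeline_tp`, and the symmetric-memory registration call are assumptions about the underlying PyTorch machinery, not code from this PR; only the `experimental.enable_async_tensor_parallel` option name comes from the PR itself.

```python
# Hypothetical sketch, not this PR's implementation.
from torch._inductor import config as inductor_config
from torch.distributed._symmetric_memory import enable_symm_mem_for_group


def maybe_enable_async_tp(job_config, tp_mesh):
    """Turn on async-TP compiler passes if the experimental option is set."""
    if not job_config.experimental.enable_async_tensor_parallel:
        return
    # Inductor pass (assumed knob) that decomposes all-gather/reduce-scatter
    # around the matmuls they feed and overlaps communication with compute.
    inductor_config._micro_pipeline_tp = True
    # Async-TP exchanges shards through symmetric-memory buffers, so the TP
    # process group has to be registered for symmetric memory first.
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
```

Because the rewrites happen as compiler passes, a toggle like this only changes the trace when the model is actually compiled, e.g. when a compile flag such as `--training.compile` (name assumed) is set alongside `--experimental.enable_async_tensor_parallel`.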
1 parent: 236aa92

File tree

1 file changed: +0 −5 lines


train.py (−5 lines)

@@ -252,11 +252,6 @@ def loss_fn(pred, labels):
         for m in model_parts
     ]
 
-    # for ease of testing TP in lieu of FSDP
-    if job_config.training.tensor_parallel_degree == world_size:
-        for model in model_parts:
-            model.to(torch.bfloat16)
-
     init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
     for model in model_parts:
         model.to_empty(device=init_device)