Commit c285a39

Author: Yifu Wang
Add the option to turn on async-TP
ghstack-source-id: b55199cfa94d8fc4510d634a2ff06cf9b6ae09f5
Pull Request resolved: #429
1 parent: cb73810

File tree: 3 files changed (+17, -0 lines)


torchtitan/config_manager.py (+6)

@@ -241,6 +241,12 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--experimental.enable_async_tensor_parallel",
+            default=False,
+            action="store_true",
+            help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
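
For illustration, a minimal standalone sketch (not torchtitan's actual JobConfig plumbing) of how the new store_true flag behaves once parsed: it defaults to False and flips to True only when passed explicitly, and argparse keeps the dotted option name verbatim as the attribute name.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.enable_async_tensor_parallel",
    default=False,
    action="store_true",
    help="Whether to apply async tensor parallel "
    "(currently only effective when compile is enabled)",
)

# Omitting the flag leaves it False; passing it flips it to True.
args = parser.parse_args([])
assert getattr(args, "experimental.enable_async_tensor_parallel") is False
args = parser.parse_args(["--experimental.enable_async_tensor_parallel"])
assert getattr(args, "experimental.enable_async_tensor_parallel") is True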

torchtitan/parallelisms/parallelize_llama.py (+6)

@@ -394,6 +394,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         parallelize_plan=layer_plan,
     )
 
+    if job_config.experimental.enable_async_tensor_parallel:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
     logger.info("Applied Tensor Parallelism to the model")
     return model
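
For context, a minimal sketch (hypothetical names, not part of this commit) of the two ingredients the added block wires together: Inductor's micro-pipeline TP pass and symmetric memory for the TP process group. The pass only fires on compiled regions, which is why the flag's help text says it is only effective when compile is enabled.

import torch
import torch._inductor.config
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import init_device_mesh

# Assumes a torchrun launch so the default process group can be initialized.
tp_mesh = init_device_mesh(
    "cuda", (torch.cuda.device_count(),), mesh_dim_names=("tp",)
)

# 1. Ask Inductor to decompose matmul + collective pairs (all-gather -> matmul,
#    matmul -> reduce-scatter) into overlapped, micro-pipelined kernels.
torch._inductor.config._micro_pipeline_tp = True

# 2. Back the TP group's collectives with symmetric memory, which those fused
#    kernels rely on.
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

# 3. The parallelized model still has to go through torch.compile for the pass
#    to apply, e.g.:
# model = torch.compile(model)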

train.py (+5)

@@ -252,6 +252,11 @@ def loss_fn(pred, labels):
         for m in model_parts
     ]
 
+    # for ease of testing TP in lieu of FSDP
+    if job_config.training.tensor_parallel_degree == world_size:
+        for model in model_parts:
+            model.to(torch.bfloat16)
+
     init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
     for model in model_parts:
         model.to_empty(device=init_device)
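
On the train.py change: per the inline comment, this is for TP-only test runs where FSDP is not applied, so the parameters are cast to bfloat16 by hand. A minimal sketch (hypothetical module, not torchtitan's Llama) of what that cast does:

import torch
import torch.nn as nn

# Stand-in for one entry of model_parts.
model = nn.Linear(16, 16)
assert model.weight.dtype == torch.float32

# Same call as in the diff above: casts all floating-point parameters
# and buffers to bfloat16 in place.
model.to(torch.bfloat16)
assert model.weight.dtype == torch.bfloat16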
