
Commit 1ec2ece

Author: Yifu Wang

Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

Parent: f5171cb

File tree: 2 files changed (+12, −0)


torchtitan/config_manager.py (+6)
@@ -241,6 +241,12 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--experimental.enable_async_tensor_parallel",
+            default=False,
+            action="store_true",
+            help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
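
As context for the new flag, here is a minimal sketch of the default=False, store_true pattern it uses, written against plain argparse rather than torchtitan. Note that argparse stores dotted option names verbatim, which is why the value is read back with getattr below; torchtitan's own config machinery regroups such names into nested sections like job_config.experimental.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.enable_async_tensor_parallel",
    default=False,
    action="store_true",
    help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
)

# Flag omitted: the option keeps its default of False.
args = parser.parse_args([])
print(getattr(args, "experimental.enable_async_tensor_parallel"))  # False

# Flag present: store_true flips it to True.
args = parser.parse_args(["--experimental.enable_async_tensor_parallel"])
print(getattr(args, "experimental.enable_async_tensor_parallel"))  # True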

torchtitan/parallelisms/parallelize_llama.py (+6)
@@ -394,6 +394,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         parallelize_plan=layer_plan,
     )

+    if job_config.experimental.enable_async_tensor_parallel:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
     logger.info("Applied Tensor Parallelism to the model")
     return model
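
These two calls are what actually switch async-TP on: the private inductor flag enables the micro-pipeline TP pass (which is why the option is only effective under torch.compile), and enable_symm_mem_for_group registers the tensor-parallel process group for symmetric-memory communication. A minimal standalone sketch of the same sequence, assuming a torchrun-style launch and a PyTorch build where these private APIs exist:

import torch
import torch.distributed as dist
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import init_device_mesh

# Assumes a torchrun-style launch so rank/world-size env vars are set.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# One-dimensional device mesh acting as the tensor-parallel group.
tp_mesh = init_device_mesh(
    "cuda", (dist.get_world_size(),), mesh_dim_names=("tp",)
)

# Ask inductor to micro-pipeline matmul + collective patterns (private
# flag; only takes effect for code compiled with torch.compile).
torch._inductor.config._micro_pipeline_tp = True

# Back the TP group's collectives with symmetric-memory buffers.
enable_symm_mem_for_group(tp_mesh.get_group().group_name)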
