[PP] add flexible interleaved 1f1b schedule
ghstack-source-id: 7cb9665c5512800773f08e413367bb4c2a4caa33
Pull Request resolved: #490
H-Huang committed Jul 29, 2024
1 parent 668f6cd commit 69d0cb2
Showing 3 changed files with 23 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test_runner.py
@@ -254,6 +254,21 @@ def build_test_list():
             requires_seed_checkpoint=True,
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
+                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm throws cuda context error with pp
+                ],
+            ],
+            "PP looped flexible 1f1b test",
+            "pp_looped_flexible_1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
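Why this test exercises a looped schedule: the overrides above cut the model at 7 split points, giving 8 pipeline stages on a pipeline_parallel_degree of 4, so each rank owns 2 stages. A minimal sketch of that arithmetic (illustrative only, not part of the commit):

# Illustrative stage math for the test definition above.
split_points = [f"layers.{i}" for i in range(1, 8)]  # 7 cuts
pp_degree = 4
num_stages = len(split_points) + 1                   # 7 cuts -> 8 stages
assert num_stages % pp_degree == 0
stages_per_rank = num_stages // pp_degree            # 2 stages per rank -> "looped"
print(f"{num_stages} stages, {stages_per_rank} per rank")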
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
@@ -275,7 +275,7 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1f1b", "gpipe", "interleaved_1f1b"],
+            choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
             default="1f1b",
             help="""
                 Specify the Pipeline Parallel schedule to use.
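For reference, a hedged sketch of how the extended choices list behaves when parsed. This reproduces the argparse call above with a standalone parser; torchtitan's full JobConfig also merges TOML files, which is omitted here:

import argparse

# Standalone reproduction of the argument defined above (assumption:
# plain argparse behavior, no TOML merging).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.pipeline_parallel_schedule",
    type=str,
    choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
    default="1f1b",
)
args = parser.parse_args(
    ["--experimental.pipeline_parallel_schedule", "flexible_interleaved_1f1b"]
)
# argparse keeps the dotted dest verbatim, so getattr needs the full name.
print(getattr(args, "experimental.pipeline_parallel_schedule"))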
7 changes: 7 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -7,6 +7,7 @@

 from torch.distributed.pipelining import (
     Schedule1F1B,
+    ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
 )
@@ -23,6 +24,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
     elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
         schedule_class = ScheduleInterleaved1F1B
         looped_schedule = True
+    elif (
+        job_config.experimental.pipeline_parallel_schedule
+        == "flexible_interleaved_1f1b"
+    ):
+        schedule_class = ScheduleFlexibleInterleaved1F1B
+        looped_schedule = True
     else:
         raise NotImplementedError(
             f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
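Downstream of this dispatch, a looped schedule is constructed from all stages owned by the rank rather than a single stage. A minimal sketch of that construction, assuming the standard torch.distributed.pipelining constructor signature (a list of stages, n_microbatches, loss_fn); the helper name is hypothetical, not the commit's code:

from torch.distributed.pipelining import ScheduleFlexibleInterleaved1F1B

def make_flexible_schedule(stages, n_microbatches, loss_fn):
    # Looped schedules receive the full list of this rank's stages,
    # e.g. 2 stages per rank in the 8-stage / 4-rank test above.
    return ScheduleFlexibleInterleaved1F1B(
        stages, n_microbatches=n_microbatches, loss_fn=loss_fn
    )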
