From 08e29b05a95877b2da939644bf95579925f1085e Mon Sep 17 00:00:00 2001
From: Will Constable <whc@meta.com>
Date: Thu, 23 May 2024 17:02:11 -0700
Subject: [PATCH] Add test for PP tracer frontend

- switch to using public PipelineStage API
- clean up some asserts in tracer codepath
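
For reference, a rough sketch of the public tracer-frontend flow this
switches to (hand-written illustration, not lifted verbatim from the diff;
model, tokens, n_microbatches, stage_idx, device and pp_group are
placeholders, and the call shapes mirror pipeline_llama_tracer below, so
they may differ in other torch versions):

    from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint

    # trace the model and split it before layers.1 (matches the 2-stage test config)
    split_spec = {"layers.1": SplitPoint.BEGINNING}
    pipe = pipeline(model, n_microbatches, (tokens,), split_spec=split_spec)

    # grab this rank's stage submodule and wrap the traced pipe with the
    # public PipelineStage class instead of the private _PipelineStage helper
    stage_module = pipe.get_stage_module(stage_idx)
    stage = PipelineStage(pipe, stage_index=stage_idx, device=device, group=pp_group)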

ghstack-source-id: 2d069b7d45c4f3c788dec8fc85d8a7e83e463fcd
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/357
---
 test_runner.py                               | 13 ++++++
 torchtitan/parallelisms/parallelize_llama.py | 46 ++++++++++----------
 2 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/test_runner.py b/test_runner.py
index 834fc080a..59bc49a47 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -113,6 +113,19 @@ def build_test_list(args):
             "PP+TP 2D test",
             requires_seed_checkpoint=True,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_tracer/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with tracer
+                ],
+            ],
+            "PP tracer frontend test",
+            requires_seed_checkpoint=True,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 425d3abe0..3617eb23c 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -18,10 +18,11 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.pipelining import pipeline, SplitPoint
-from torch.distributed.pipelining.PipelineStage import (
-    _PipelineStage,
+from torch.distributed.pipelining import (
     ManualPipelineStage,
+    pipeline,
+    PipelineStage,
+    SplitPoint,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -159,6 +160,14 @@ def _llama_trace_input(job_config, model_config, device="meta"):
     return (tokens,)
 
 
+def _mixed_precision_dtype(
+    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
+) -> torch.dtype:
+    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    mp_arg = job_config.training.mixed_precision_param
+    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
+
+
 def pipeline_llama_manual(
     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
 ):
@@ -204,8 +213,7 @@ def pipeline_llama_manual(
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
-    mp_arg = job_config.training.mixed_precision_param
-    mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
+    mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
@@ -216,12 +224,7 @@ def pipeline_llama_manual(
     )
     if pp_rank == 0:
         # first layer
-        input = torch.randint(
-            model_config.vocab_size,
-            size=(batch_size, job_config.training.seq_len),
-            dtype=torch.int64,
-            device=device,
-        )
+        (input,) = _llama_trace_input(job_config, model_config, device=device)
     else:
         # later layers (assume all start w/ a transformer layer)
         input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
@@ -257,21 +260,21 @@ def pipeline_llama_tracer(
             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
         )
 
-    # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
-    raise NotImplementedError(
-        "pipeline tracer doesn't work with fsdp mixed precision currently. "
-        "To work around, edit fsdp mixed precision config to use fp32."
-    )
+    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+        raise NotImplementedError(
+            "pipeline tracer doesn't work with fsdp mixed precision currently. "
+            "To work around, edit fsdp mixed precision config to use fp32."
+        )
+
     pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
-    stage_idx = pp_mesh.get_local_rank()
+    stage_idx = pp_rank
     layers_per_rank = len(model.layers) // parallel_dims.pp
     split_spec = {
         f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
         for i in range(1, parallel_dims.pp)
     }
 
-    # Create a pipeline representation from the model
     pipe = pipeline(
         model,
         job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
@@ -279,10 +282,9 @@ def pipeline_llama_tracer(
         split_spec=split_spec,
     )
     model = pipe.get_stage_module(stage_idx)
-    stage = _PipelineStage(
-        stage_module=model,
-        stage_index=pp_rank,
-        pipe_info=pipe.pipe_info,
+    stage = PipelineStage(
+        pipe,
+        stage_index=stage_idx,
         device=device,
         group=pp_mesh.get_group(),
     )