Add test for PP tracer frontend #357

Merged (2 commits) on May 24, 2024
test_runner.py (13 additions, 0 deletions)

@@ -113,6 +113,19 @@ def build_test_list(args):
"PP+TP 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tracer/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
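For readers unfamiliar with the test harness, here is a minimal sketch of how an OverrideDefinitions entry like the one above is typically turned into a launch command. The run_override_test helper, the ./run_llama_train.sh path, the config file, and the NGPU setting are illustrative assumptions, not taken from this PR; the real runner also handles the seed-checkpoint step implied by requires_seed_checkpoint=True, which is omitted here.

```python
import subprocess


def run_override_test(override_args: list[list[str]], config_file: str) -> None:
    """Sketch: each inner list of flags becomes one training run with the flags appended."""
    for override in override_args:
        # Hypothetical launch command; the real runner also wires in log ranks, seed checkpoints, etc.
        cmd = f"CONFIG_FILE={config_file} NGPU=4 ./run_llama_train.sh " + " ".join(override)
        subprocess.run(cmd, shell=True, check=True)


# The new "PP tracer frontend test" would expand to roughly:
run_override_test(
    [
        [
            "--checkpoint.enable_checkpoint",
            "--job.dump_folder ./outputs/pp_tracer/",
            "--experimental.pipeline_parallel_degree 2",
            "--experimental.pipeline_parallel_split_points layers.1",
            "--model.norm_type rmsnorm",
        ]
    ],
    config_file="./train_configs/debug_model.toml",
)
```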
torchtitan/parallelisms/parallelize_llama.py (24 additions, 22 deletions)

@@ -18,10 +18,11 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.pipelining import pipeline, SplitPoint
-from torch.distributed.pipelining.PipelineStage import (
-    _PipelineStage,
+from torch.distributed.pipelining import (
+    ManualPipelineStage,
+    pipeline,
+    PipelineStage,
+    SplitPoint,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -159,6 +160,14 @@ def _llama_trace_input(job_config, model_config, device="meta"):
     return (tokens,)
 
 
+def _mixed_precision_dtype(
+    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
+) -> torch.dtype:
+    """Get the mixed precision dtype if FSDP (data parallelism) is enabled, otherwise return the default."""
+    mp_arg = job_config.training.mixed_precision_param
+    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
+
+
 def pipeline_llama_manual(
     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
 ):
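As a quick sanity check of the new helper's behavior, here is a self-contained sketch using SimpleNamespace stand-ins for the real JobConfig and ParallelDims objects; the stand-ins and the local dtype map are assumptions for illustration only.

```python
import torch
from types import SimpleNamespace

# Minimal stand-ins for the real JobConfig / ParallelDims objects.
job_config = SimpleNamespace(training=SimpleNamespace(mixed_precision_param="bfloat16"))
parallel_dims_dp_on = SimpleNamespace(dp_enabled=True)
parallel_dims_dp_off = SimpleNamespace(dp_enabled=False)

TORCH_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}


def _mixed_precision_dtype(job_config, parallel_dims, default=torch.float32):
    # Mirrors the helper above: the FSDP mixed-precision dtype only applies when data parallelism is on.
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


assert _mixed_precision_dtype(job_config, parallel_dims_dp_on) is torch.bfloat16
assert _mixed_precision_dtype(job_config, parallel_dims_dp_off) is torch.float32
```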
@@ -204,8 +213,7 @@ def pipeline_llama_manual(
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
-    mp_arg = job_config.training.mixed_precision_param
-    mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
+    mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
@@ -216,12 +224,7 @@
     )
     if pp_rank == 0:
         # first layer
-        input = torch.randint(
-            model_config.vocab_size,
-            size=(batch_size, job_config.training.seq_len),
-            dtype=torch.int64,
-            device=device,
-        )
+        (input,) = _llama_trace_input(job_config, model_config, device=device)
     else:
         # later layers (assume all start w/ a transformer layer)
         input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
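For intuition about the stage input shapes and dtypes above, here is a small sketch with made-up dimensions; only the shape arithmetic and dtypes mirror the diff, and all concrete numbers are illustrative assumptions.

```python
import torch

# Illustrative numbers; the real values come from the job config and model config.
vocab_size, batch_size, seq_len, tp_degree, dim = 32000, 8, 2048, 2, 256
mp_dtype = torch.bfloat16
device = "cpu"  # "cuda" in a real run

local_seq_len = int(seq_len // tp_degree)           # 1024: sequence dim is sharded across TP ranks
layers_io_shape = (batch_size, local_seq_len, dim)  # activation shape passed between PP stages

# Stage 0 consumes token ids; this is what _llama_trace_input returns and what the
# removed torch.randint block used to build inline.
tokens = torch.randint(vocab_size, size=(batch_size, seq_len), dtype=torch.int64, device=device)

# Later stages consume transformer-layer activations in the mixed-precision dtype.
activations = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
```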
@@ -257,32 +260,31 @@ def pipeline_llama_tracer(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)

# TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
)
if _mixed_precision_dtype(job_config) == torch.bfloat16:
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
)

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
stage_idx = pp_mesh.get_local_rank()
stage_idx = pp_rank
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
}

# Create a pipeline representation from the model
pipe = pipeline(
model,
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
example_args=_llama_trace_input(job_config, model_config),
split_spec=split_spec,
)
model = pipe.get_stage_module(stage_idx)
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe.pipe_info,
stage = PipelineStage(
pipe,
stage_index=stage_idx,
device=device,
group=pp_mesh.get_group(),
)
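To see the tracer frontend end to end, here is a toy sketch that mirrors the split_spec construction and pipeline() call above. It assumes the torch.distributed.pipelining API as used in this PR (positional microbatch count, example_args, and PipelineStage taking the traced pipe); later PyTorch releases reworked these entry points. The toy model, layer count, and sizes are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import pipeline, SplitPoint


class ToyModel(nn.Module):
    """Stand-in for the Llama model: blocks live under "layers" so FQNs like "layers.2" resolve."""

    def __init__(self, dim=64, n_layers=4, vocab_size=1000):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h)
        return self.output(h)


pp_degree = 2
model = ToyModel()
layers_per_rank = len(model.layers) // pp_degree  # 4 // 2 = 2

# One BEGINNING split per stage boundary, e.g. {"layers.2": SplitPoint.BEGINNING} here;
# the integration test's "layers.1" flag plays the same role for the 2-layer debug model.
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING for i in range(1, pp_degree)
}

example_tokens = torch.randint(1000, (8, 16), dtype=torch.int64)
pipe = pipeline(model, pp_degree, example_args=(example_tokens,), split_spec=split_spec)

# Per rank, the stage would then be built roughly as in the diff above, e.g.
# stage = PipelineStage(pipe, stage_index=pp_rank, device=device, group=pp_mesh.get_group()),
# which requires an initialized process group and device mesh, omitted here.
```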