Additional configs to enable 3D parallelism (pytorch#17)
* enable PP and 3D

* add code to enable 3D parallelism

---------

Co-authored-by: Tianyu Liu <lty@fb.com>
ruisizhang123 and tianyu-l authored Sep 20, 2024
1 parent 9790417 commit 88a0054
Showing 6 changed files with 63 additions and 12 deletions.
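For orientation before the per-file diffs: 3D parallelism here means composing pipeline parallel (PP), data parallel (FSDP), and tensor parallel (TP) over a single device mesh. The sketch below is not part of the commit; it only illustrates, with placeholder degrees, how such a mesh is typically built with the public DeviceMesh API that the torchtitan configs sit on top of.

# Illustrative only: a 3-D device mesh for PP x DP x TP (degrees are placeholders).
# Assumes torch.distributed is already initialized across pp * dp * tp == 8 ranks.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]  # used to split the model into pipeline stages
dp_mesh = world_mesh["dp"]  # used by fully_shard / FSDP
tp_mesh = world_mesh["tp"]  # used by tensor parallel (parallelize_module)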
28 changes: 25 additions & 3 deletions benchmark.py
@@ -78,7 +78,6 @@ def main(job_config: JobConfig):
        batch_size=job_config.training.batch_size,
        extra_args=[],
    )
    model_flops = get_model_flops(config)
    benchmark_model = load_model(config)
    model, _ = benchmark_model.get_module()

@@ -104,12 +103,35 @@ def main(job_config: JobConfig):
    )

    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
    if job_config.experimental.torch_spmd:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
    else:
        from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

        param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
        reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

        for name, block in model.named_children():
            block = torch.compile(block)
            model.register_module(name, block)

        for name, block in model.named_children():
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=True,
            )
        fully_shard(model, **fsdp_config, reshard_after_forward=True)
    # update model and optimizer after applying parallelisms
    benchmark_model.set_module(model)
    optimizer = benchmark_model.get_optimizer()
    optimizer.add_param_group({"params": model.parameters()})
    if optimizer is not None:
        optimizer.add_param_group({"params": model.parameters()})

    model.train()

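A self-contained sketch of the per-block compile-then-fully_shard pattern that the new branch above follows. ToyBlock, compile_and_shard, and the single-dimension mesh are assumptions for illustration, not code from this commit, and a distributed process group must already be initialized.

# Hypothetical, minimal reproduction of the pattern: compile each child module,
# re-register it, then apply fully_shard per child and once more at the root.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh

class ToyBlock(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ffn(x)

def compile_and_shard(model: nn.Module, dp_mesh) -> nn.Module:
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    # compile each child first, then shard child-by-child, then shard the root module
    for name, block in model.named_children():
        model.register_module(name, torch.compile(block))
    for _, block in model.named_children():
        fully_shard(block, **fsdp_config, reshard_after_forward=True)
    fully_shard(model, **fsdp_config, reshard_after_forward=True)
    return model

# Example usage (assumes launch via torchrun with an initialized process group):
# dp_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",))
# model = compile_and_shard(nn.Sequential(ToyBlock(), ToyBlock()).cuda(), dp_mesh)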
4 changes: 3 additions & 1 deletion run_llama_train.sh
@@ -19,8 +19,10 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

# TORCH_TRACE="./outputs/trace" \
# TORCH_NCCL_AVOID_RECORD_STREAMS=1: clear AG/RS copy-in copy-out memory after use
# TORCHINDUCTOR_FORCE_DISABLE_CACHES=1: avoid TP crash during compilation
TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
30 changes: 28 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -38,8 +38,31 @@ def torch_spmd_parallelize(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    # ensure full graph compile on non-llama models
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    # simple fsdp configs
    torch._inductor.config.simplefsdp.bucket_mode = "greedy"
    torch._inductor.config.simplefsdp.enable_reorder = True
    torch._inductor.config.simplefsdp.enable_bucket = True
    torch._inductor.config.simplefsdp.degree = parallel_dims.dp
    torch._inductor.config.simplefsdp.tp_enabled = parallel_dims.tp_enabled
    torch._inductor.config.simplefsdp.pp_enabled = parallel_dims.pp_enabled
    torch._inductor.config.simplefsdp.device_mesh = world_mesh.mesh.tolist()

    print("enable reorder", torch._inductor.config.simplefsdp.enable_reorder)
    if torch._inductor.config.simplefsdp.bucket_mode == "transformer_block":
        print("enable block-level bucket")
    elif torch._inductor.config.simplefsdp.bucket_mode == "greedy":
        print("enable greedy auto bucket")
        print(
            "ag_comm_time_multiplier",
            torch._inductor.config.simplefsdp.ag_comm_time_multiplier,
        )
        print(
            "rs_comm_time_multiplier",
            torch._inductor.config.simplefsdp.rs_comm_time_multiplier,
        )

    if parallel_dims.tp_enabled:
        apply_tp(
@@ -329,6 +352,8 @@ def apply_compile(model: nn.Module):
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    # TODO(ruisizhang123): avoid recompiling error on 70B model
    torch._dynamo.config.cache_size_limit = 128
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)
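A small aside on the cache_size_limit bump above: a hedged sketch, not part of the commit, of how one might surface the recompilations that motivate it, using PyTorch's logging hooks. The threshold value 128 is simply copied from the diff.

# Hypothetical debugging aid: log recompilation reasons and raise the same
# per-code-object cache limit that apply_compile sets above.
import torch

torch._logging.set_logs(recompiles=True)       # print why a frame is recompiled
torch._dynamo.config.cache_size_limit = 128    # same limit the commit sets

@torch.compile(fullgraph=True)
def layer_like(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x @ x.T)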
@@ -361,7 +386,8 @@ def apply_fsdp(
    if pp_enabled:
        # For PP, do not reshard after forward to avoid per-microbatch
        # all-gathers, which can be expensive and non-overlapped
        reshard_after_forward = False
        # NOTE: set reshard_after_forward to True for fair comparison with simple FSDP
        reshard_after_forward = True
    else:
        # As an optimization, do not reshard after forward for the last
        # transformer block since FSDP would prefetch it immediately
6 changes: 5 additions & 1 deletion train.py
@@ -10,6 +10,7 @@
from datetime import timedelta

import torch
from torch.distributed._tensor import DTensor
from torch.distributed.elastic.multiprocessing.errors import record

# context needed by meta-init with torch_spmd
@@ -139,6 +140,9 @@ def main(job_config: JobConfig):

    # loss function to be shared by Pipeline Parallel and SPMD training
    def loss_fn(pred, labels):
        # TODO(ruisizhang123): temporary fix to enable async TP for full model compile
        if isinstance(pred, DTensor):
            pred._local_tensor = pred._local_tensor.contiguous()
        return torch.nn.functional.cross_entropy(
            pred.flatten(0, 1), labels.flatten(0, 1)
        )
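For context, a standalone form of the loss function above with the imports it needs; everything here mirrors the diff, and only the comment is an added description, so treat it as a sketch rather than the authoritative torchtitan helper.

# Hypothetical standalone version of loss_fn; only the DTensor branch is specific
# to the async-TP + full-model-compile path mentioned in the TODO above.
import torch
from torch.distributed._tensor import DTensor

def cross_entropy_loss(pred, labels):
    if isinstance(pred, DTensor):
        # workaround from the commit: force the local shard to be contiguous
        pred._local_tensor = pred._local_tensor.contiguous()
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1), labels.flatten(0, 1)
    )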
@@ -162,7 +166,7 @@ def loss_fn(pred, labels):
            m.train()

            # TODO(lty): need to find a better way to apply torch.compile
            if job_config.training.compile:
            if job_config.training.compile and job_config.experimental.torch_spmd:
                stages[i].submod = torch.compile(m, fullgraph=True)
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
4 changes: 0 additions & 4 deletions train_configs/benchmark_model.toml
@@ -27,10 +27,6 @@ max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
compile = true
mixed_precision_param = "bfloat16"
mixed_precision_reduce = "bfloat16"
# mixed_precision_param = "float32"
# mixed_precision_reduce = "float32"

[experimental]
torch_spmd = true
3 changes: 2 additions & 1 deletion train_configs/llama3_70b.toml
@@ -37,7 +37,8 @@ compile = false
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1
#pipeline_parallel_degree = 8
#pipeline_parallel_split_points=["layers.10","layers.20","layers.30","layers.40","layers.50","layers.60","layers.70"]

[checkpoint]
enable_checkpoint = false
