Update TP examples to align with tutorials
as titled
wanchaol committed Apr 11, 2024
1 parent 7df10c2 commit 9ca0137
Showing 6 changed files with 685 additions and 71 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/main_distributed.yaml
@@ -0,0 +1,38 @@
name: Run Distributed Examples

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  schedule:
    # Every day at 3:00am
    - cron: '0 3 * * *'


jobs:
  test:

    runs-on: 4-core-ubuntu-gpu-t4

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.8
      uses: actions/setup-python@v2
      with:
        python-version: 3.8
    - name: Install PyTorch
      run: |
        python -m pip install --upgrade pip
        pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu118/torch_nightly.html
    - name: Run Tests
      run: |
        ./run_distributed_examples.sh "install_deps,run_all,clean"
    - name: Open issue on failure
      if: ${{ failure() && github.event_name == 'schedule' }}
      uses: rishabhgupta/git-action-issue@v2
      with:
        token: ${{ secrets.GITHUB_TOKEN }}
        title: Daily CI failed
        body: Commit ${{ github.sha }} daily scheduled [CI run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) failed, please check why
        assignees: ''
16 changes: 8 additions & 8 deletions distributed/tensor_parallelism/README.md
@@ -1,14 +1,14 @@
-# PyTorch Tensor Parallelism for distributed training
+# PyTorch native Tensor Parallel for distributed training

-This example demonstrates SPMD Megatron-LM style tensor parallel by using
-PyTorch native Tensor Parallelism APIs, which include:
+This example demonstrates SPMD Megatron-LM style Tensor Parallel by using
+PyTorch native Tensor Parallel APIs, which include:

-1. High-level APIs for module-level parallelism with a dummy MLP model.
-2. Model agnostic ops for `DistributedTensor`, such as `Linear` and `RELU`.
-3. A E2E demo of tensor parallel for a given toy model (Forward/backward + optimization).
+1. Simple module-level Tensor Parallelism on a dummy MLP model.
+2. Simple module-level Tensor Parallelism with Sequence Parallel inputs/outputs on a dummy MLP model.
+3. An E2E demo of Fully Sharded Data Parallel + Tensor Parallel (with Sequence Parallel) on an example Llama2 model.

-More details about the design can be found:
-https://github.com/pytorch/pytorch/issues/89884
+For more details about the PyTorch native Tensor Parallel APIs, please see the PyTorch docs:
+https://pytorch.org/docs/stable/distributed.tensor.parallel.html

```
pip install -r requirements.txt
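For readers who want a concrete picture of the module-level API that item 1 of the README refers to, here is a minimal, illustrative sketch. It is not part of this commit: the `ToyMLP` module and the 2-GPU mesh are assumptions, and it needs a distributed launch such as `torchrun --nproc-per-node=2 toy_tp.py`.

```
import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)


class ToyMLP(nn.Module):
    """Two-layer MLP used only to illustrate a parallelize_module plan."""

    def __init__(self, dim: int = 1024) -> None:
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.out_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.relu(self.in_proj(x)))


# One tensor-parallel group spanning both ranks of this launch.
tp_mesh = init_device_mesh("cuda", (2,))

model = ToyMLP().to("cuda")

# Shard the first linear column-wise and the second row-wise: the intermediate
# activation stays sharded across the TP group, and only the second matmul
# needs an all-reduce to produce the replicated output.
model = parallelize_module(
    model,
    tp_mesh,
    {
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

torch.manual_seed(0)  # identical input on every rank; TP expects replicated inputs
out = model(torch.rand(8, 1024, device="cuda"))  # matches the single-GPU result up to FP error
```

The same Colwise-then-Rowwise pairing shows up in the attention and feed-forward plans of the Llama2 example below.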
142 changes: 80 additions & 62 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,20 +1,12 @@
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

@@ -23,13 +15,24 @@
    sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh
from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    PrepareModuleInput,
    SequenceParallel
)


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example
Llama2 model. We show an E2E working flow from forward, backward
and optimization.
We enabled Fully Sharded Data Parallel + Tensor Parallel in
@@ -53,41 +56,10 @@
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================
More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
More details can be seen in the PyTorch tutorials:
https://pytorch.org/tutorials/intermediate/TP_tutorial.html
"""


def find_multiple(n: int, k: int) -> int:
    """function to find resizing multiple for SwiGLU MLP"""
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP_swiglu(nn.Module):
    """SwiGLU to showcase a Llama style MLP model"""

    def __init__(self, mlp_dim: int = 1024) -> None:
        super().__init__()
        hidden_dim = 4 * mlp_dim
        scaled_hidden = int(2 * hidden_dim / 3)
        rounded_hidden = find_multiple(scaled_hidden, 256)

        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
        x = self.out_proj(x)
        return x


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
tp_size = 2
logger = get_logger()

@@ -120,26 +92,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# to mimic the behavior of the dataloader.
dp_rank = dp_mesh.get_local_rank()

# create model and move it to GPU with id rank
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(
    module=base_model_swiglu,
    device_mesh=tp_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "gate_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)

model = Transformer.from_model_args(simple_llama2_config).to("cuda")

# init model weights
model.init_weights()

# parallelize the first embedding and the last linear out projection
model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
        "norm": SequenceParallel(),
        "layers.0": PrepareModuleInput(
            input_layouts=(Replicate(), None),
            desired_input_layouts=(Shard(1), None),
            use_local_output=True,
        ),
    }
)

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")
for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "attention_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
        "ffn_norm": SequenceParallel(),
    }

    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

    # Custom parallelization plan for the model
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan
    )

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)
sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)

rank_log(_rank, logger, f"Model after parallelization {sharded_model=}\n")

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
@@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for i in range(num_iterations):
# seeding with dp_rank to ensure identical inputs for TP groups
torch.manual_seed(i + dp_rank)
inp = torch.rand(batch_size, _mlp_dim, device="cuda")
inp = torch.randint(32000, (8, 256), device="cuda")

output = sharded_model(inp)
output.sum().backward()