Adding a distributed pipeline parallelism example for RPC (pytorch#749)
Co-authored-by: Shen Li <shenli@devfair017.maas>
mrshenli and Shen Li authored Apr 23, 2020
1 parent 391be73 commit d431037
Showing 3 changed files with 299 additions and 0 deletions.
9 changes: 9 additions & 0 deletions distributed/rpc/pipeline/README.md
@@ -0,0 +1,9 @@
Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model across two RPC workers and
then implement distributed pipeline parallelism using RPC.

```
pip install -r requirements.txt
python main.py
```
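
For a quick mental model, the forward pass in `main.py` splits every input batch
into micro-batches and pipelines them across the two workers roughly as in the
sketch below (simplified for illustration; `run_stage`, `stage1_rref`, and
`stage2_rref` are placeholder names, not objects defined in the example):

```
# Simplified sketch of the micro-batch pipeline. Assumes RPC is already
# initialized and the two model shards live behind stage1_rref / stage2_rref.
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef

def run_stage(stage_rref, x_rref):
    # Runs on the owner of stage_rref; mirrors _call_method in main.py.
    return stage_rref.local_value().forward(x_rref)

def pipeline_forward(xs, split_size, stage1_rref, stage2_rref):
    futures = []
    for x in xs.split(split_size, dim=0):
        # Stage 1 produces an RRef to its activation on the first worker ...
        y_rref = rpc.remote(stage1_rref.owner(), run_stage, args=(stage1_rref, RRef(x)))
        # ... which stage 2 on the second worker consumes asynchronously.
        futures.append(rpc.rpc_async(stage2_rref.owner(), run_stage, args=(stage2_rref, y_rref)))
    # Waiting only at the end lets the micro-batches overlap across the two stages.
    return torch.cat([fut.wait() for fut in futures])
```

The full implementation in `main.py` additionally moves each micro-batch to the
shard's GPU and guards each shard with a lock so overlapping micro-batches do
not interleave inside a single module.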
288 changes: 288 additions & 0 deletions distributed/rpc/pipeline/main.py
@@ -0,0 +1,288 @@
import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################


def _call_method(method, rref, *args, **kwargs):
    r"""
    Call ``method`` on the local value of the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    Run ``method`` on the owner of ``rref`` and return an RRef to the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    Run ``method`` on the owner of ``rref`` and fetch the result back with an
    asynchronous RPC.
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


#########################################################
#            Define Model Parallel ResNet50             #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)


class ResNetPart1(ResNetBase):
    """
    The first part of ResNet50: the stem plus the first two residual stages.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.seq(x)
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet50: the last two residual stages, average pooling,
    and the fully connected classifier.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()


class DistResNet50(nn.Module):
    """
    Assemble the two model shards as a single nn.Module and define the
    pipelining logic.
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect async RPC
        # futures into a list
        out_futures = []
        for x in xs.split(self.split_size, dim=0):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPCs to finish
        outs = [fut.wait() for fut in out_futures]
        # cat all tensors into one tensor.
        out = torch.cat(outs)
        return out

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params


#########################################################
#                  Run RPC Processes                    #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(num_split):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)


def run_worker(rank, world_size, num_split):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # A high thread count lets the RPCs for different micro-batches run
    # concurrently instead of queueing behind each other.
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # workers passively wait for RPCs from the master

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.time()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")
2 changes: 2 additions & 0 deletions distributed/rpc/pipeline/requirements.txt
@@ -0,0 +1,2 @@
torch==1.5.0
torchvision==0.6.0
