Adding a distributed pipeline parallelism example for RPC (pytorch#749)
Co-authored-by: Shen Li <shenli@devfair017.maas>
Showing 3 changed files with 299 additions and 0 deletions.
@@ -0,0 +1,9 @@
Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model on two RPC workers and
then implement distributed pipeline parallelism using RPC.

```
pip install -r requirements.txt
python main.py
```
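A note on the assumed setup (this note and the snippet below are not part of the example files): main.py spawns three processes on one machine, a master plus two workers, and pins the two model halves to "cuda:0" and "cuda:1", so at least two GPUs are expected. A minimal pre-flight check along those lines:

```
import torch

# main.py places ResNetPart1 on cuda:0 and ResNetPart2 on cuda:1
assert torch.cuda.device_count() >= 2, "this example expects at least two GPUs"
```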
@@ -0,0 +1,288 @@
import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################


def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and return an RRef
    of the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
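# NOTE: _remote_on_rref and _async_on_rref exist because rpc.remote and
# rpc.rpc_async cannot invoke an instance method on a remote object directly.
# Both helpers ship _call_method to the RRef's owner, where it resolves
# rref.local_value() and calls the method locally. DistResNet50.forward below
# uses them as, e.g., _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref),
# which runs the first stage on its owning worker and returns an RRef to the
# intermediate activation without copying it back to the master.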


#########################################################
#           Define Model Parallel ResNet50              #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)


class ResNetPart1(ResNetBase):
    """
    The first part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.seq(x)
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()


class DistResNet50(nn.Module):
    """
    Assemble two parts as an nn.Module and define pipelining logic
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect async RPC
        # futures into a list
        out_futures = []
        for x in iter(xs.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPC to finish
        outs = [fut.wait() for fut in out_futures]
        # cat all tensors into one tensor.
        out = torch.cat(outs)
        return out

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params
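# NOTE: the forward pass above overlaps the two stages: both _remote_on_rref
# and _async_on_rref return immediately, so worker1 can start on micro-batch
# i+1 while worker2 is still working on micro-batch i, and the master only
# blocks at fut.wait(). The threading.Lock acquired inside each part's
# forward() serializes execution within a stage when several RPC threads
# arrive concurrently.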


#########################################################
#                   Run RPC Processes                   #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(num_split):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)
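# NOTE on the training loop above: the dist_autograd context records the RPCs
# made during the forward pass, dist_autograd.backward() then drives the
# backward pass across the master and both workers, and opt.step(context_id)
# has the DistributedOptimizer apply one SGD update on each worker that owns
# parameters, using the gradients accumulated under that same context_id.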


def run_worker(rank, world_size, num_split):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.time()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")
@@ -0,0 +1,2 @@
torch==1.5.0
torchvision==0.6.0 |
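The pinned versions match the ProcessGroupRpcBackendOptions API used in main.py. On newer PyTorch releases the default RPC backend is TensorPipe, so a rough equivalent of the option setup (a sketch, not something the example ships) would be:

```
import torch.distributed.rpc as rpc

# assumes a PyTorch release where TensorPipe is the default RPC backend
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
```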