Skip to content

Commit

Permalink
[feat] pp 2/n (Lightning-AI#5026)
Browse files Browse the repository at this point in the history
* Added changes for RPC plugin

* Add missing kwargs

* Fix code format

* Loading refactors by introducing is_distributed var, fix optimizer step flow

* Add rpc guard

* Added docstrings and typing

* resolve comments

* Add additional rpc hook, refactor name of exit process hook for clarity

* remove annotation

* Modify behaviour to allow optional return, add test for rpc plugin

* resolve tests

* rename is_ddp_based

* update

* update for windows

* update

* resolve test

* code smell

* Added sequential plugin

* resolve bug

* update

* cleanup

* add Exception

* resolve docs

* Remove ddp support

* Revert distributed -> ddp

* Update pl_examples/basic_examples/conv_sequential_example.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pl_examples/basic_examples/conv_sequential_example.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Address code review points

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Add missing return

* Fix formatting, add datamodule args

* add small comment

* resolve comments

* resolve comments

* update source for fairscale

* update extras

* remove staticmethod

* resolve flake8

* Skip tests that are failing due to bug upstream with multiple optimizers and shard

* update

* update on comments

* clean test

* latest comments

* remove old comments

* add todo

* Update version

* update

* resolve bugs

* resolve bugs

* update test

* remove hanging test

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve on comments

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve on comments

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/ddp_sequential_plugin.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove ImportError

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
5 people authored Dec 9, 2020
1 parent 7d9784e commit ef8ef12
Show file tree
Hide file tree
Showing 13 changed files with 881 additions and 27 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ repos:
types: [python]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: master
hooks:
- id: mypy
4 changes: 3 additions & 1 deletion benchmarks/test_sharded_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
Expand All @@ -148,6 +149,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
Expand Down Expand Up @@ -189,7 +191,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_b)
# todo: understand why synchronization breaks there.
# self.manual_backward(loss_2, opt_a, retain_graph=True)
opt_b.step()
Expand Down
216 changes: 216 additions & 0 deletions pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example script of running the experimental DDP Sequential Plugin.
This script splits a convolutional model across multiple GPUs, using the built-in balancer
to balance it across your GPUs.
To run:
python conv_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential
"""
import math
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE

if BOLTS_AVAILABLE:
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


#####################
# Modules #
#####################


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        """Return ``x`` reshaped to two dimensions, keeping the leading dimension.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example the result of a ``transpose``) are accepted as well; for
        contiguous inputs the result is the same as ``view`` would give.
        """
        return x.reshape(x.size(0), -1)

###############################
# LightningModule #
###############################


class LitResnet(pl.LightningModule):
    """Convolutional image classifier built as a single ``nn.Sequential``.

    The network is three convolutional blocks followed by a fully connected
    head with 10 outputs. ``forward`` returns log-probabilities, so every loss
    in this class is computed with ``F.nll_loss``.

    Args:
        lr: learning rate handed to the SGD optimizer.
        batch_size: batch size, used to derive the scheduler's steps per epoch.
        manual_optimization: when true, ``training_step`` is swapped for
            ``training_step_manual`` and ``automatic_optimization`` reports ``False``.
    """

    def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
        super().__init__()

        self.save_hyperparameters()
        self.sequential_module = nn.Sequential(
            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),

            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),

            Flatten(),

            # Fully connected head: 256 channels * 4 * 4 = 4096 features for a
            # 32x32 input after the three stride-2 poolings above.
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=False),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=False),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10)
        )
        self._example_input_array = torch.randn((1, 3, 32, 32))
        self._manual_optimization = manual_optimization
        if self._manual_optimization:
            # Rebind on the instance so the manual variant is the step that gets called.
            self.training_step = self.training_step_manual

    def forward(self, x):
        """Run the sequential network and return log-probabilities over the last dim."""
        out = self.sequential_module(x)
        return F.log_softmax(out, dim=-1)

    def training_step_manual(self, batch, batch_idx):
        """Training step for manual optimization: forward, backward and step in a closure."""
        opt = self.optimizers()

        def closure():
            x, y = batch
            logits = self.forward(x)
            loss = F.nll_loss(logits, y)
            self.manual_backward(loss, opt)
            self.log('train_loss', loss, prog_bar=True)

        opt.step(closure=closure)

    def training_step(self, batch, batch_idx):
        """Training step for automatic optimization: return the NLL loss of the batch."""
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        self.log('Training Loss', loss)
        return loss

    def _evaluate(self, batch, batch_idx, stage=None):
        """Compute loss and accuracy for a batch, logging them when ``stage`` is given.

        Args:
            batch: an ``(inputs, targets)`` pair.
            batch_idx: index of the batch (unused here).
            stage: optional prefix; when truthy, ``{stage}_loss`` and
                ``{stage}_acc`` are logged to the progress bar.

        Returns:
            A ``(loss, acc)`` tuple.
        """
        x, y = batch
        # forward() already applies log_softmax; applying it a second time is a
        # no-op on log-probabilities, so the output is used directly.
        log_probs = self.forward(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=-1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

        return loss, acc

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the loss."""
        return self._evaluate(batch, batch_idx, 'val')[0]

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy.

        ``_evaluate`` already logs ``test_loss`` and ``test_acc``, so they are
        not logged a second time here.
        """
        self._evaluate(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """Return SGD with a per-step ``OneCycleLR`` schedule."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        # NOTE(review): 45000 is presumably the number of training samples served by
        # the datamodule (CIFAR-10 minus a validation split) - confirm if the split changes.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    0.1,
                    epochs=self.trainer.max_epochs,
                    steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
                'interval': 'step',
            }
        }

    @property
    def automatic_optimization(self) -> bool:
        """``False`` when the module was built with ``manual_optimization=True``."""
        return not self._manual_optimization


#################################
# Instantiate Data Module #
#################################

def instantiate_datamodule(args):
    """Create the CIFAR-10 datamodule used by this example.

    Training images get a padded random crop and a random horizontal flip
    before tensor conversion and normalization; validation and test images
    are only converted and normalized.

    Args:
        args: parsed command-line arguments; only ``args.batch_size`` is read.

    Returns:
        The configured ``pl_bolts.datamodules.CIFAR10DataModule``.
    """
    tv_transforms = torchvision.transforms

    augmented_pipeline = tv_transforms.Compose([
        tv_transforms.RandomCrop(32, padding=4),
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.ToTensor(),
        cifar10_normalization(),
    ])

    # The same pipeline object is shared by the validation and test splits.
    plain_pipeline = tv_transforms.Compose([
        tv_transforms.ToTensor(),
        cifar10_normalization(),
    ])

    return pl_bolts.datamodules.CIFAR10DataModule(
        batch_size=args.batch_size,
        train_transforms=augmented_pipeline,
        test_transforms=plain_pipeline,
        val_transforms=plain_pipeline,
    )


if __name__ == "__main__":
    # Check the optional dependencies before anything else: ``pl_bolts`` is only
    # imported when BOLTS_AVAILABLE is true, so using it below without this check
    # would fail with a NameError instead of the intended message. Explicit raises
    # are used because ``assert`` statements are skipped under ``python -O``.
    if not BOLTS_AVAILABLE:
        raise ImportError("Bolts is required for this example, install it via pip install pytorch-lightning-bolts")
    if not FAIRSCALE_PIPE_AVAILABLE:
        raise ImportError("FairScale and PyTorch 1.6 is required for this example.")

    parser = ArgumentParser(description="Pipe Example")
    parser.add_argument("--use_ddp_sequential", action="store_true")
    parser = Trainer.add_argparse_args(parser)
    parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
    args = parser.parse_args()

    cifar10_dm = instantiate_datamodule(args)

    # The sequential plugin is opt-in; without the flag no plugin is passed to the Trainer.
    plugins = [DDPSequentialPlugin()] if args.use_ddp_sequential else None

    # NOTE(review): ``automatic_optimization`` is presumably one of the flags added by
    # ``Trainer.add_argparse_args`` - confirm against the installed Lightning version.
    model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization)

    trainer = pl.Trainer.from_argparse_args(args, plugins=plugins)
    trainer.fit(model, cifar10_dm)
    trainer.test(model, datamodule=cifar10_dm)

    if trainer.accelerator_backend.rpc_enabled:
        # Called at the end of trainer to ensure all processes are killed
        trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
4 changes: 3 additions & 1 deletion pytorch_lightning/overrides/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
"""
Override the forward call in lightning so it goes to training and validation step respectively
"""
PREPARE_FOR_BACKWARDS = True

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
Expand All @@ -165,6 +166,7 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
fx_called: str = ''

if self.device_ids:

inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
# --------------
Expand Down Expand Up @@ -195,7 +197,7 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
else:
output = self.module.validation_step(*inputs, **kwargs)

if not self._reducer_prepared_for_backwards:
if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
self.reducer_prepare_for_backwards(output)

if output is None:
Expand Down
Loading

0 comments on commit ef8ef12

Please sign in to comment.