forked from Lightning-AI/pytorch-lightning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added changes for RPC plugin * Add missing kwargs * Fix code format * Loading refactors by introducing is_distributed var, fix optimizer step flow * Add rpc guard * Added docstrings and typing * resolve comments * Add additional rpc hook, refactor name of exit process hook for clarity * remove annotation * Modify behaviour to allow optional return, add test for rpc plugin * resolve tests * rename is_ddp_based * update * update for windows * update * resolve test * code smell * Added sequential plugin * resolve bug * update * cleanup * add Exception * resolve docs * Remove ddp support * Revert distributed -> ddp * Update pl_examples/basic_examples/conv_sequential_example.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/basic_examples/conv_sequential_example.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Address code review points * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add missing return * Fix formatting, add datamodule args * add small comment * resolve comments * resolve comments * update source for fairscale * update extras * remove staticmethod * resolve flake8 * Skip tests that are failing due to bug upstream with multiple optimizers and shard * update * update on comments * clean test * latest comments * remove old comments * add todo * Update version * update * resolve bugs * resolve bugs * update test * remove hanging test * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * resolve on comments * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * resolve on comments * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/plugins/ddp_sequential_plugin.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * remove ImportError Co-authored-by: SeanNaren <sean@grid.ai> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
- Loading branch information
1 parent
7d9784e
commit ef8ef12
Showing
13 changed files
with
881 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,5 +32,6 @@ repos: | |
types: [python] | ||
|
||
- repo: https://github.com/pre-commit/mirrors-mypy | ||
rev: master | ||
hooks: | ||
- id: mypy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Example script of running the experimental DDP Sequential Plugin. | ||
This script splits a convolutional model onto multiple GPUs, whilst using the internal built in balancer | ||
to balance across your GPUs. | ||
To run: | ||
python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_ddp_sequential | ||
""" | ||
import math | ||
from argparse import ArgumentParser | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torchvision | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning import Trainer | ||
from pytorch_lightning.metrics.functional import accuracy | ||
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin | ||
from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE | ||
|
||
if BOLTS_AVAILABLE: | ||
import pl_bolts | ||
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | ||
|
||
|
||
##################### | ||
# Modules # | ||
##################### | ||
|
||
|
||
class Flatten(nn.Module): | ||
def forward(self, x): | ||
return x.view(x.size(0), -1) | ||
|
||
############################### | ||
# LightningModule # | ||
############################### | ||
|
||
|
||
class LitResnet(pl.LightningModule): | ||
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False): | ||
super().__init__() | ||
|
||
self.save_hyperparameters() | ||
self.sequential_module = nn.Sequential( | ||
# Conv Layer block 1 | ||
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(32), | ||
nn.ReLU(inplace=False), | ||
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=False), | ||
nn.MaxPool2d(kernel_size=2, stride=2), | ||
|
||
# Conv Layer block 2 | ||
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(128), | ||
nn.ReLU(inplace=False), | ||
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=False), | ||
nn.MaxPool2d(kernel_size=2, stride=2), | ||
nn.Dropout2d(p=0.05), | ||
|
||
# Conv Layer block 3 | ||
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(256), | ||
nn.ReLU(inplace=False), | ||
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=False), | ||
nn.MaxPool2d(kernel_size=2, stride=2), | ||
|
||
Flatten(), | ||
|
||
nn.Dropout(p=0.1), | ||
nn.Linear(4096, 1024), | ||
nn.ReLU(inplace=False), | ||
nn.Linear(1024, 512), | ||
nn.ReLU(inplace=False), | ||
nn.Dropout(p=0.1), | ||
nn.Linear(512, 10) | ||
) | ||
self._example_input_array = torch.randn((1, 3, 32, 32)) | ||
self._manual_optimization = manual_optimization | ||
if self._manual_optimization: | ||
self.training_step = self.training_step_manual | ||
|
||
def forward(self, x): | ||
out = self.sequential_module(x) | ||
return F.log_softmax(out, dim=-1) | ||
|
||
def training_step_manual(self, batch, batch_idx): | ||
opt = self.optimizers() | ||
|
||
def closure(): | ||
x, y = batch | ||
logits = self.forward(x) | ||
loss = F.nll_loss(logits, y) | ||
self.manual_backward(loss, opt) | ||
self.log('train_loss', loss, prog_bar=True) | ||
|
||
opt.step(closure=closure) | ||
|
||
def training_step(self, batch, batch_idx): | ||
x, y = batch | ||
logits = self.forward(x) | ||
loss = F.nll_loss(logits, y) | ||
self.log('Training Loss', loss) | ||
return loss | ||
|
||
def _evaluate(self, batch, batch_idx, stage=None): | ||
x, y = batch | ||
out = self.forward(x) | ||
logits = F.log_softmax(out, dim=-1) | ||
loss = F.nll_loss(logits, y) | ||
preds = torch.argmax(logits, dim=-1) | ||
acc = accuracy(preds, y) | ||
|
||
if stage: | ||
self.log(f'{stage}_loss', loss, prog_bar=True) | ||
self.log(f'{stage}_acc', acc, prog_bar=True) | ||
|
||
return loss, acc | ||
|
||
def validation_step(self, batch, batch_idx): | ||
return self._evaluate(batch, batch_idx, 'val')[0] | ||
|
||
def test_step(self, batch, batch_idx): | ||
loss, acc = self._evaluate(batch, batch_idx, 'test') | ||
self.log_dict({'test_loss': loss, 'test_acc': acc}) | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) | ||
return { | ||
'optimizer': optimizer, | ||
'lr_scheduler': { | ||
'scheduler': torch.optim.lr_scheduler.OneCycleLR( | ||
optimizer, | ||
0.1, | ||
epochs=self.trainer.max_epochs, | ||
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)), | ||
'interval': 'step', | ||
} | ||
} | ||
|
||
@property | ||
def automatic_optimization(self) -> bool: | ||
return not self._manual_optimization | ||
|
||
|
||
################################# | ||
# Instantiate Data Module # | ||
################################# | ||
|
||
def instantiate_datamodule(args): | ||
train_transforms = torchvision.transforms.Compose([ | ||
torchvision.transforms.RandomCrop(32, padding=4), | ||
torchvision.transforms.RandomHorizontalFlip(), | ||
torchvision.transforms.ToTensor(), | ||
cifar10_normalization(), | ||
]) | ||
|
||
test_transforms = torchvision.transforms.Compose([ | ||
torchvision.transforms.ToTensor(), | ||
cifar10_normalization(), | ||
]) | ||
|
||
cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule( | ||
batch_size=args.batch_size, | ||
train_transforms=train_transforms, | ||
test_transforms=test_transforms, | ||
val_transforms=test_transforms, | ||
) | ||
|
||
return cifar10_dm | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser(description="Pipe Example") | ||
parser.add_argument("--use_ddp_sequential", action="store_true") | ||
parser = Trainer.add_argparse_args(parser) | ||
parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser) | ||
args = parser.parse_args() | ||
|
||
assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts" | ||
assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example." | ||
|
||
cifar10_dm = instantiate_datamodule(args) | ||
|
||
plugins = None | ||
if args.use_ddp_sequential: | ||
plugins = DDPSequentialPlugin() | ||
|
||
model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization) | ||
|
||
trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None) | ||
trainer.fit(model, cifar10_dm) | ||
trainer.test(model, datamodule=cifar10_dm) | ||
|
||
if trainer.accelerator_backend.rpc_enabled: | ||
# Called at the end of trainer to ensure all processes are killed | ||
trainer.accelerator_backend.ddp_plugin.exit_rpc_process() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.