Parity test #1284

Merged: 27 commits, Mar 30, 2020
2 changes: 1 addition & 1 deletion .drone.yml
@@ -22,7 +22,7 @@ steps:
  - pip install -r ./tests/requirements.txt --user
  - pip list
  - python -c "import torch ; print(' & '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) if torch.cuda.is_available() else 'only CPU')"
- - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests -v --doctest-modules # --flake8
+ - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests benchmarks -v --doctest-modules # --flake8
  - coverage report
  - codecov --token $CODECOV_TOKEN # --pr $DRONE_PULL_REQUEST --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG
  - python tests/collect_env_details.py
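
The change above appends the new `benchmarks` package to the existing py.test invocation, so the parity test runs under coverage in CI alongside the regular test suite. A minimal local equivalent, assuming the test requirements are installed, is `python -m pytest benchmarks -v` from the repository root.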
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

  ### Added

+ - Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
  - Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
  - Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
  - Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
Empty file added benchmarks/__init__.py
151 changes: 151 additions & 0 deletions benchmarks/test_trainer_parity.py
@@ -0,0 +1,151 @@
import os
import time

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning import Trainer, LightningModule


class ParityMNIST(LightningModule):

def __init__(self):
super(ParityMNIST, self).__init__()
self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
self.c_d1_bn = nn.BatchNorm1d(128)
self.c_d1_drop = nn.Dropout(0.3)
self.c_d2 = nn.Linear(in_features=128, out_features=10)

def forward(self, x):
x = x.view(x.size(0), -1)
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
return x

def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return {'loss': loss}

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)

def train_dataloader(self):
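        # no shuffle is requested, so the DataLoader (shuffle=False by default)
        # feeds both training loops the batches in the same order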
return DataLoader(MNIST(os.getcwd(), train=True, download=True,
transform=transforms.ToTensor()), batch_size=32)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir):
"""
Verify that the same pytorch and lightning models achieve the same results
:param tmpdir:
:return:
"""
    num_epochs = 2
    num_runs = 3
    lightning_outs, pl_times = lightning_loop(ParityMNIST, num_runs, num_epochs)
    manual_outs, pt_times = vanilla_loop(ParityMNIST, num_runs, num_epochs)

    # make sure the losses match to 5 decimal places
    # (the wall-clock times are collected but not asserted on)
for pl_out, pt_out in zip(lightning_outs, manual_outs):
np.testing.assert_almost_equal(pl_out, pt_out, 5)


def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
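    # cuDNN is left at its defaults here; a stricter variant could also set
    # torch.backends.cudnn.deterministic = True for bitwise GPU reproducibility,
    # but the parity check only compares losses to 5 decimal places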


def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
"""
Returns an array with the last loss from each epoch for each run
"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
errors = []
times = []

for i in range(num_runs):
time_start = time.perf_counter()

# set seed
seed = i
set_seed(seed)

# init model parts
model = MODEL()
dl = model.train_dataloader()
optimizer = model.configure_optimizers()

# model to GPU
model = model.to(device)

epoch_losses = []
for epoch in range(num_epochs):

# run through full training set
for j, batch in enumerate(dl):
x, y = batch
                x = x.to(device)
                y = y.to(device)
                batch = (x, y)

loss_dict = model.training_step(batch, j)
loss = loss_dict['loss']
loss.backward()
optimizer.step()
optimizer.zero_grad()

            # track the loss from the last batch of this epoch
            epoch_losses.append(loss.item())

time_end = time.perf_counter()
times.append(time_end - time_start)

errors.append(epoch_losses[-1])

return errors, times


def lightning_loop(MODEL, num_runs=10, num_epochs=10):
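    """
    Train MODEL num_runs times with the Lightning Trainer; return the trainer's
    last running-loss value and the wall-clock time of each run.
    """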
errors = []
times = []

for i in range(num_runs):
time_start = time.perf_counter()

# set seed
seed = i
set_seed(seed)

# init model parts
model = MODEL()
trainer = Trainer(
max_epochs=num_epochs,
show_progress_bar=False,
weights_summary=None,
gpus=1,
early_stop_callback=False,
checkpoint_callback=False
)
trainer.fit(model)

final_loss = trainer.running_loss.last().item()
errors.append(final_loss)

time_end = time.perf_counter()
times.append(time_end - time_start)

return errors, times
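
For reference, a minimal sketch of driving the two loops directly outside of pytest (hypothetical usage, assuming a CUDA machine, since both loops move the model to the GPU):

import numpy as np

from benchmarks.test_trainer_parity import ParityMNIST, lightning_loop, vanilla_loop

# the same seeds are set inside each loop, so the final losses should agree
lightning_outs, pl_times = lightning_loop(ParityMNIST, num_runs=3, num_epochs=2)
manual_outs, pt_times = vanilla_loop(ParityMNIST, num_runs=3, num_epochs=2)
np.testing.assert_almost_equal(lightning_outs, manual_outs, 5)
print('lightning: %.1f s/run, vanilla: %.1f s/run' % (np.mean(pl_times), np.mean(pt_times)))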