From 94fbaa518de6b36af89e1bca37a4b534dabb3668 Mon Sep 17 00:00:00 2001 From: Pariente Manuel Date: Fri, 3 Apr 2020 15:43:14 +0200 Subject: [PATCH] Add parity test for simple RNN (#1351) * Add parity test for simple RNN * Update test_rnn_parity.py Co-authored-by: William Falcon --- benchmarks/test_rnn_parity.py | 155 ++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 benchmarks/test_rnn_parity.py diff --git a/benchmarks/test_rnn_parity.py b/benchmarks/test_rnn_parity.py new file mode 100644 index 00000000000000..34549e2dc3a9f8 --- /dev/null +++ b/benchmarks/test_rnn_parity.py @@ -0,0 +1,155 @@ +import time + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader + +from pytorch_lightning import Trainer, LightningModule + + +class AverageDataset(Dataset): + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + +class ParityRNN(LightningModule): + def __init__(self): + super(ParityRNN, self).__init__() + self.rnn = nn.LSTM(10, 20, batch_first=True) + self.linear_out = nn.Linear(in_features=20, out_features=5) + + def forward(self, x): + seq, last = self.rnn(x) + return self.linear_out(seq) + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self(x) + loss = F.mse_loss(y_hat, y) + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=30) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_pytorch_parity(tmpdir): + """ + Verify that the same pytorch and lightning models achieve the same results + :param tmpdir: + :return: + """ + num_epochs = 2 + num_rums = 3 + + lightning_outs, pl_times = lightning_loop(ParityRNN, num_rums, num_epochs) + manual_outs, pt_times = vanilla_loop(ParityRNN, num_rums, num_epochs) + # make sure the losses match exactly to 5 decimal places + for pl_out, pt_out in zip(lightning_outs, manual_outs): + np.testing.assert_almost_equal(pl_out, pt_out, 8) + + +def set_seed(seed): + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +def vanilla_loop(MODEL, num_runs=10, num_epochs=10): + """ + Returns an array with the last loss from each epoch for each run + """ + device = torch.device('cuda' if torch.cuda.is_available() else "cpu") + errors = [] + times = [] + + for i in range(num_runs): + time_start = time.perf_counter() + + # set seed + seed = i + set_seed(seed) + + # init model parts + model = MODEL() + dl = model.train_dataloader() + optimizer = model.configure_optimizers() + + # model to GPU + model = model.to(device) + + epoch_losses = [] + for epoch in range(num_epochs): + + # run through full training set + for j, batch in enumerate(dl): + x, y = batch + x = x.cuda(0) + y = y.cuda(0) + batch = (x, y) + + loss_dict = model.training_step(batch, j) + loss = loss_dict['loss'] + loss.backward() + optimizer.step() + optimizer.zero_grad() + + # track last epoch loss + epoch_losses.append(loss.item()) + + time_end = time.perf_counter() + times.append(time_end - time_start) + + errors.append(epoch_losses[-1]) + + return errors, times + + +def lightning_loop(MODEL, num_runs=10, num_epochs=10): + errors = [] + times = [] + + for i in range(num_runs): + time_start = time.perf_counter() + + # set seed + seed = i + set_seed(seed) + + # init model parts + model = MODEL() + trainer = Trainer( + max_epochs=num_epochs, + show_progress_bar=False, + weights_summary=None, + gpus=1, + early_stop_callback=False, + checkpoint_callback=False, + distributed_backend='dp', + ) + trainer.fit(model) + + final_loss = trainer.running_loss.last().item() + errors.append(final_loss) + + time_end = time.perf_counter() + times.append(time_end - time_start) + + return errors, times