[tune] a tiny ptl example (#11497)
richardliaw authored and Alex Wu committed Oct 23, 2020
1 parent 2d9b735 commit 395ddb0
Showing 7 changed files with 146 additions and 7 deletions.
10 changes: 5 additions & 5 deletions .travis.yml
@@ -276,7 +276,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_QUICK_TRAIN_AND_MISC_TESTS=1
# TODO (sven): Remove this after fixing rllib tests num_cpus.
-  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
+  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
- PYTHON=3.6
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
@@ -297,7 +297,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_EXAMPLE_DIR_TESTS=1
# TODO (sven): Remove this after fixing rllib tests num_cpus.
-  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
+  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
- PYTHON=3.6
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
@@ -318,7 +318,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1
# TODO (sven): Remove this after fixing rllib tests num_cpus.
-  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
+  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
- PYTHON=3.6
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
@@ -336,7 +336,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1
# TODO (sven): Remove this after fixing rllib tests num_cpus.
-  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
+  - RAY_USE_MULTIPROCESSING_CPU_COUNT=1
- PYTHON=3.6
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
@@ -357,7 +357,7 @@ matrix:
- PYTHON=3.6
- TF_VERSION=2.2.0
- TFP_VERSION=0.8
-  - TORCH_VERSION=1.5
+  - TORCH_VERSION=1.6
- PYTHONWARNINGS=ignore
install:
- . ./ci/travis/ci.sh init RAY_CI_TUNE_AFFECTED
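The final hunk bumps the Tune CI job (RAY_CI_TUNE_AFFECTED) from torch 1.5 to 1.6; that bump is likely what necessitates the pbt_dcgan_mnist dtype fix further down.
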
7 changes: 6 additions & 1 deletion doc/source/tune/examples/index.rst
@@ -44,10 +44,15 @@ PyTorch Examples
~~~~~~~~~~~~~~~~

- :doc:`/tune/examples/mnist_pytorch`: Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune.
-  - :doc:`/tune/examples/mnist_pytorch_lightning`: Uses `Pytorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ to train a MNIST model. This example utilizes the Ray Tune-provided :ref:`PyTorch Lightning callbacks <tune-integration-pytorch-lightning>`. See also :ref:`this tutorial for a full walkthrough <tune-pytorch-lightning>`.
- :doc:`/tune/examples/mnist_pytorch_trainable`: Converts the PyTorch MNIST example to use Tune with the Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end.
- :doc:`/tune/examples/ddp_mnist_torch`: An example showing how to use DistributedDataParallel with Ray Tune. This enables both distributed training and distributed hyperparameter tuning.

+PyTorch Lightning Examples
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- :doc:`/tune/examples/mnist_ptl_mini`: A minimal example that trains an MNIST model with `PyTorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ and the Ray Tune-provided :ref:`PyTorch Lightning callbacks <tune-integration-pytorch-lightning>`. See also :ref:`this tutorial for a full walkthrough <tune-pytorch-lightning>`.
+- :doc:`/tune/examples/mnist_pytorch_lightning`: A comprehensive example using `PyTorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ to train an MNIST model, showcasing several search-optimization techniques and using the Ray Tune-provided :ref:`PyTorch Lightning callbacks <tune-integration-pytorch-lightning>`. See also :ref:`this tutorial for a full walkthrough <tune-pytorch-lightning>`.


XGBoost Example
~~~~~~~~~~~~~~~
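Both new bullets point at the same integration; as a minimal sketch of the hookup (using only the TuneReportCallback API exercised by the new example further down), the callback is attached like any other PyTorch Lightning callback:

import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback

# Report the value logged as "ptl/val_loss" to Tune under the name
# "loss" at the end of every validation epoch.
callback = TuneReportCallback({"loss": "ptl/val_loss"}, on="validation_end")
trainer = pl.Trainer(max_epochs=1, callbacks=[callback])
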
6 changes: 6 additions & 0 deletions doc/source/tune/examples/mnist_ptl_mini.rst
@@ -0,0 +1,6 @@
:orphan:

mnist_ptl_mini
~~~~~~~~~~~~~~

.. literalinclude:: /../../python/ray/tune/examples/mnist_ptl_mini.py
9 changes: 9 additions & 0 deletions python/ray/tune/BUILD
@@ -475,6 +475,15 @@ py_test(
args = ["--smoke-test"]
)

+py_test(
+    name = "mnist_ptl_mini",
+    size = "medium",
+    srcs = ["examples/mnist_ptl_mini.py"],
+    deps = [":tune_lib"],
+    tags = ["exclusive", "example", "pytorch"],
+    args = ["--smoke-test"]
+)

py_test(
    name = "mnist_pytorch_trainable",
    size = "small",
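The new target mirrors the neighboring example tests: Bazel runs examples/mnist_ptl_mini.py with --smoke-test, so CI executes only the single one-epoch trial wired up in the script's __main__ block.
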
117 changes: 117 additions & 0 deletions python/ray/tune/examples/mnist_ptl_mini.py
@@ -0,0 +1,117 @@
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback

import tempfile
from ray import tune


class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
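        # log_softmax produces log-probabilities, which pair with the
        # F.nll_loss calls in the training and validation steps below.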
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

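    # Average the per-batch validation metrics so one value per epoch is
    # logged; self.log() is what the TuneReportCallback reads from.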
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)


def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
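    # Map the names Tune reports ("loss", "acc") to the keys logged via
    # self.log(); the TuneReportCallback sends them to Tune after each
    # validation epoch.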
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)


def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()

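    # Search space: layer widths and batch size are sampled from fixed
    # choices; the learning rate is sampled log-uniformly.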
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

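    # tune.with_parameters passes the constant arguments (data_dir, epoch
    # count, GPU count) to every trial without putting them in the
    # per-trial config.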
    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)
    tune.run(
        trainable,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
    else:
        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)
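
Run directly, the script launches ten trials of up to ten epochs each; with --smoke-test (the flag the BUILD target above passes) it collapses to a single one-epoch trial.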
3 changes: 2 additions & 1 deletion python/ray/tune/examples/pbt_dcgan_mnist/common.py
@@ -185,7 +185,8 @@ def train(netD, netG, optimG, optimD, criterion, dataloader, iteration, device,
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
-       label = torch.full((b_size, ), real_label, device=device)
+       label = torch.full(
+           (b_size, ), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
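The explicit dtype is the usual fix for the stricter torch.full behavior in PyTorch 1.6, the version the .travis.yml change above moves to: inferring a dtype from an integer Python fill value is deprecated there, while criterion(output, label) needs float labels. A minimal sketch of the effect, assuming real_label is the integer 1:

import torch

real_label = 1
b_size = 4

# Without dtype=..., PyTorch 1.6 warns that torch.full will stop
# returning a floating tensor for integer fill values; passing the
# dtype explicitly keeps the labels float, as BCE-style losses expect.
label = torch.full((b_size, ), real_label, dtype=torch.float)
assert label.dtype == torch.float32
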
1 change: 1 addition & 0 deletions python/requirements_tune.txt
@@ -18,6 +18,7 @@ nevergrad
optuna
pytest-remotedata>=0.3.1
pytorch-lightning
+pytorch-lightning-bolts
scikit-optimize
sigopt
smart_open
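pytorch-lightning-bolts provides the pl_bolts.datamodules.MNISTDataModule imported at the top of the new example.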
