This repository has been archived by the owner on Sep 19, 2022. It is now read-only.

Adding distributed mnist example with summaries #118

Merged: 1 commit, Jan 3, 2019
2 changes: 1 addition & 1 deletion examples/katib/Dockerfile
@@ -2,4 +2,4 @@ FROM pytorch/pytorch:0.4_cuda9_cudnn7

 RUN pip install tensorboardX
 ADD . /opt/pytorch_dist_mnist
-ENTRYPOINT ["python", "/opt/pytorch_dist_mnist/mnist_with_summary.py"]
+ENTRYPOINT ["python", "/opt/pytorch_dist_mnist/dist_mnist_with_summary.py"]
132 changes: 132 additions & 0 deletions examples/katib/dist_mnist_with_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import print_function
import argparse
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tensorboardX import SummaryWriter
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--dir', default='logs', metavar='L',
                    help='directory where summary logs are stored')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
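# The Normalize values (0.1307, 0.3081) above are the MNIST training-set
# mean and standard deviation.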


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()
if args.cuda:
    model.cuda()

print('Learning rate: {} Momentum: {} Logs dir: {}'.format(args.lr, args.momentum, args.dir))
writer = SummaryWriter(args.dir)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

def average_gradients():
    world_size = dist.get_world_size()
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
        param.grad.data /= float(world_size)
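# Summing each gradient across replicas with all_reduce and dividing by the
# world size leaves every worker holding the identical averaged gradient, so
# each optimizer.step() applies the same update on every replica (synchronous
# data-parallel SGD over the combined global batch).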

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        average_gradients()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('loss', loss.item(), niter)

def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # volatile=True is a no-op in PyTorch 0.4; disable autograd explicitly
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()  # sum up batch loss
            pred = output.data.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(test_loader.dataset)
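    # The 'accuracy=' line printed below doubles as a machine-readable metric;
    # Katib-style metrics collectors can presumably scrape it from stdout.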
    print('\naccuracy={:.4f}\n'.format(float(correct) / len(test_loader.dataset)))
    writer.add_scalar('accuracy', float(correct) / len(test_loader.dataset), epoch)


def init_processes(backend='tcp'):
    """ Initialize the distributed environment. """
    dist.init_process_group(backend)
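# With no explicit init_method, init_process_group defaults to the env://
# rendezvous, reading MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the
# environment; the PyTorch operator injects these into each PyTorchJob pod.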

if __name__ == "__main__":
    init_processes()
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
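For quick verification outside a cluster, a two-process run on one machine can mimic the operator's environment. The sketch below is illustrative only and not part of this PR; the gloo backend, port 29500, and world size of 2 are assumptions:

# local_smoke_test.py (hypothetical helper, not included in this PR)
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def worker(rank, world_size):
    # Mimic the variables the PyTorch operator would inject into a pod.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'  # assumed to be a free local port
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    dist.init_process_group(backend='gloo')
    t = torch.ones(1)
    dist.all_reduce(t, op=dist.reduce_op.SUM)  # every rank should see world_size
    print('rank {}: all_reduce -> {}'.format(rank, t.item()))

if __name__ == '__main__':
    world_size = 2
    procs = [Process(target=worker, args=(r, world_size)) for r in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()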
12 changes: 10 additions & 2 deletions examples/katib/pytorch_job_mnist.yaml
@@ -4,11 +4,19 @@ metadata:
   name: "pytorch-mnist-with-summary"
 spec:
   pytorchReplicaSpecs:
-    Worker:
+    Master:
       replicas: 1
       restartPolicy: OnFailure
       template:
         spec:
           containers:
           - name: pytorch
-            image: gcr.io/kubeflow-ci/pytorch-mnist-with-summary:0.4
+            image: gcr.io/kubeflow-ci/pytorch-mnist-with-summary:1.0
+    Worker:
+      replicas: 3
+      restartPolicy: OnFailure
+      template:
+        spec:
+          containers:
+          - name: pytorch
+            image: gcr.io/kubeflow-ci/pytorch-mnist-with-summary:1.0
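With this split, the operator treats the single Master replica as rank 0 and the three Workers as ranks 1 through 3, so the script's all_reduce averaging runs over a world size of 4 (assuming the standard pytorch-operator rank assignment).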