Skip to content

Commit

Permalink
Add training on PyTorch (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreysher authored Oct 9, 2023
1 parent 360bf09 commit bb3e0aa
Show file tree
Hide file tree
Showing 12 changed files with 514 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.idea*
*__pycache__*
21 changes: 21 additions & 0 deletions benchmark/training_torch/image_classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MLPerf Tiny image classification PyTorch model

This is the MLPerf Tiny image classification PyTorch model.

A ResNet8 model is trained on the CIFAR10 dataset available at:
https://www.cs.toronto.edu/~kriz/cifar.html

Model: ResNet8
Dataset: CIFAR-10

## Quick start

Run the following commands to go through the whole training and validation process

```Bash
# Prepare Python venv (Python 3.7+ and pip>20 required)
./prepare_training_env.sh

# Download the training data, train the model, test the model
./download_cifar10_train_resnet.sh
```
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
# Train and then evaluate the MLPerf Tiny image classification model
# inside the project venv created by prepare_training_env.sh.
# Fail fast: without `set -e`, test.py would still run (and load a stale
# or missing checkpoint) even if training failed.
set -e
. venv/bin/activate

# train and test the model
python3 train.py
python3 test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/sh
# Create an isolated Python venv and install the pinned training dependencies.
# Fail fast: if venv creation fails, do not attempt to activate it or
# install packages into the wrong interpreter.
set -e

python3 -m venv venv
. venv/bin/activate
pip3 install -r requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tensorboard==2.14.1
torch==2.0.1
torchvision==0.15.2
50 changes: 50 additions & 0 deletions benchmark/training_torch/image_classification/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import argparse

import torch
from utils.data import get_test_dataloader
from utils.training import eval_training

if __name__ == "__main__":
    # Evaluate a trained checkpoint on the CIFAR10 test split and print accuracy.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-ckpt",
        default="trained_models/best.pth",
        type=str,
        help="Path to model checkpoint for evaluation.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size. Default value is 32 according to TF training procedure.",
    )
    parser.add_argument(
        "--data-dir",
        default="cifar-10-torch",
        type=str,
        help="Path to dataset (will be downloaded).",
    )
    parser.add_argument(
        "--workers", default=4, type=int, help="Number of data loading processes."
    )
    args = parser.parse_args()

    # Remap tensors onto an available device so a checkpoint saved on GPU
    # still loads on a CPU-only machine (torch.load would otherwise raise).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    model = torch.load(args.model_ckpt, map_location=device)

    val_loader = get_test_dataloader(
        cifar_10_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.workers,
    )
    loss_function = torch.nn.CrossEntropyLoss()

    # epoch=0 / writer=None: standalone evaluation, no TensorBoard logging.
    accuracy = eval_training(
        model=model,
        dataloader=val_loader,
        loss_function=loss_function,
        epoch=0,
        log_to_tensorboard=False,
        writer=None,
    )

    print(f"Model {args.model_ckpt} has accuracy: {accuracy}")
126 changes: 126 additions & 0 deletions benchmark/training_torch/image_classification/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import argparse

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from utils.data import get_test_dataloader
from utils.data import get_training_dataloader
from utils.model import Resnet8v1EEMBC
from utils.training import WarmUpLR
from utils.training import eval_training
from utils.training import train_one_epoch

if __name__ == "__main__":
    # Train ResNet8 on CIFAR10, evaluate each epoch, and keep the best checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size. Default value is 32 according to TF training procedure.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of epochs. Default value is 500 according to TF training procedure.",
    )
    parser.add_argument(
        "--warmup-epochs",
        type=int,
        default=0,
        help="Number of epochs for LR linear warmup.",
    )
    parser.add_argument(
        "--lr",
        default=0.001,
        type=float,
        help="Initial learning rate. Default value is 1e-3 according to TF training procedure.",
    )
    parser.add_argument(
        "--lr-decay",
        default=0.99,
        type=float,
        help="Multiplicative LR decay applied per epoch (lr * decay**epoch).",
    )
    parser.add_argument(
        "--data-dir",
        default="cifar-10-torch",
        type=str,
        help="Path to dataset (will be downloaded).",
    )
    parser.add_argument(
        "--workers", default=4, type=int, help="Number of data loading processes."
    )
    parser.add_argument(
        "--weight-decay", default=1e-4, type=float, help="Weight decay for optimizer."
    )
    parser.add_argument("--log-dir", type=str, default="trained_models")
    args = parser.parse_args()

    train_loader = get_training_dataloader(
        cifar_10_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.workers,
    )
    val_loader = get_test_dataloader(
        cifar_10_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.workers,
    )

    model = Resnet8v1EEMBC()
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = Adam(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Exponential decay: lr(epoch) = base_lr * lr_decay**epoch.
    train_scheduler = LambdaLR(
        optimizer=optimizer, lr_lambda=lambda epoch: args.lr_decay**epoch
    )
    warmup_scheduler = None
    if args.warmup_epochs:
        warmup_scheduler = WarmUpLR(optimizer=optimizer, total_iters=args.warmup_epochs)

    writer = SummaryWriter(log_dir=args.log_dir)

    loss_function = torch.nn.CrossEntropyLoss()

    best_accuracy = 0.0
    for epoch in range(1, args.epochs + 1):
        # Decay schedule takes over once warmup (if any) has finished.
        if epoch > args.warmup_epochs:
            train_scheduler.step()

        train_one_epoch(
            model=model,
            train_dataloader=train_loader,
            loss_function=loss_function,
            optimizer=optimizer,
            epoch=epoch,
            writer=writer,
            warmup_scheduler=warmup_scheduler,
            warmup_epochs=args.warmup_epochs,
            train_scheduler=train_scheduler,
        )

        accuracy = eval_training(
            model=model,
            dataloader=val_loader,
            loss_function=loss_function,
            epoch=epoch,
            log_to_tensorboard=True,
            writer=writer,
        )

        # Save full model (not just state_dict) so test.py can torch.load it.
        if best_accuracy < accuracy:
            weights_path = f"{args.log_dir}/best.pth"
            print(f"saving weights file to {weights_path}")
            torch.save(model, weights_path)
            best_accuracy = accuracy

        # BUG FIX: the original `continue` after saving a new best model
        # skipped this flush on exactly the epochs worth logging.
        writer.flush()

    writer.close()
Empty file.
97 changes: 97 additions & 0 deletions benchmark/training_torch/image_classification/utils/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def get_training_dataloader(
    cifar_10_dir: str = "cifar-10-torch",
    batch_size: int = 16,
    num_workers: int = 2,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader:
    """Build the CIFAR10 training DataLoader with standard augmentations.
    Parameters
    ----------
    cifar_10_dir: str
        Path to CIFAR10 data root in torchvision format.
    batch_size: int
        Batch size for dataloader.
    num_workers: int
        Number of subprocesses for data loading.
    shuffle: bool
        Flag for shuffling training data.
    Returns
    -------
    torch.utils.data.DataLoader
        DataLoader for training data.
    """
    # Augmentation pipeline: random crop with padding, horizontal flip,
    # small rotation, then conversion to a float tensor in [0, 1].
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ]
    train_set = torchvision.datasets.CIFAR10(
        root=cifar_10_dir,
        train=True,
        download=True,
        transform=transforms.Compose(augmentations),
    )
    return DataLoader(
        train_set,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
    )


def get_test_dataloader(
    cifar_10_dir: str = "cifar-10-torch",
    batch_size: int = 16,
    num_workers: int = 2,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader:
    """Create DataLoader for test data.
    Parameters
    ----------
    cifar_10_dir: str
        Path to CIFAR10 data root in torchvision format.
    batch_size: int
        Batch size for dataloader.
    num_workers: int
        Number of subprocesses for data loading.
    shuffle: bool
        Flag for shuffling test data.
        NOTE(review): defaults to True, which is unusual for evaluation;
        accuracy is order-independent, but confirm before relying on batch order.
    Returns
    -------
    torch.utils.data.DataLoader
        DataLoader for test data.
    """
    # Test images only need tensor conversion -- no augmentation.
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    cifar10_test = torchvision.datasets.CIFAR10(
        root=cifar_10_dir,
        train=False,
        download=True,
        transform=transform_test,
    )
    cifar10_test_loader = DataLoader(
        cifar10_test,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    return cifar10_test_loader
75 changes: 75 additions & 0 deletions benchmark/training_torch/image_classification/utils/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
from torch import nn
from torch.nn import functional as F


class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 conv+BN layers plus a skip connection.

    The skip is an identity when the channel counts match, otherwise a
    1x1 convolution projects the input to the output shape.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
    ):
        super().__init__()
        # Main path; the first conv carries the (optional) stride.
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=True,
                stride=stride,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(num_features=out_channels),
        ]
        self.block = nn.Sequential(*layers)
        # Identity skip when shapes already match; 1x1 projection otherwise.
        self.residual = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
            )
        )

    def forward(self, inputs):
        main_path = self.block(inputs)
        skip_path = self.residual(inputs)
        return F.relu(main_path + skip_path)


class Resnet8v1EEMBC(nn.Module):
    """ResNet8 (MLPerf Tiny / EEMBC variant) classifier with 10 outputs.

    Stem conv -> three residual stages (16, 32, 64 channels; the last two
    halve the spatial size) -> global average pool -> linear head.
    """

    def __init__(self):
        super().__init__()
        # 3x3 stem: 3 -> 16 channels, spatial resolution unchanged.
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=True
            ),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(inplace=True),
        )

        self.first_stack = ResNetBlock(in_channels=16, out_channels=16, stride=1)
        self.second_stack = ResNetBlock(in_channels=16, out_channels=32, stride=2)
        self.third_stack = ResNetBlock(in_channels=32, out_channels=64, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=64, out_features=10)

    def forward(self, inputs):
        features = self.stem(inputs)
        for stage in (self.first_stack, self.second_stack, self.third_stack):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)
        # Raw logits; CrossEntropyLoss applies the softmax downstream.
        return self.fc(pooled)
Loading

0 comments on commit bb3e0aa

Please sign in to comment.