-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
360bf09
commit 5e3288c
Showing
12 changed files
with
514 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.idea* | ||
*__pycache__* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# MLPerf Tiny image classification PyTorch model | ||
|
||
This is the MLPerf Tiny image classification PyTorch model. | ||
|
||
A ResNet8 model is trained on the CIFAR10 dataset available at: | ||
https://www.cs.toronto.edu/~kriz/cifar.html | ||
|
||
Model: ResNet8 | ||
Dataset: CIFAR10 | ||
|
||
## Quick start | ||
|
||
Run the following commands to go through the whole training and validation process. | ||
|
||
```Bash | ||
# Prepare Python venv (Python 3.7+ and pip>20 required) | ||
./prepare_training_env.sh | ||
|
||
# Download training, train model, test the model | ||
./download_cifar10_train_resnet.sh | ||
``` |
Empty file.
6 changes: 6 additions & 0 deletions
6
benchmark/training_torch/image_classification/download_cifar10_train_resnet.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
. venv/bin/activate | ||
|
||
# train ans test the model | ||
python3 train.py | ||
python3 test.py |
5 changes: 5 additions & 0 deletions
5
benchmark/training_torch/image_classification/prepare_training_env.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/bin/sh | ||
|
||
python3 -m venv venv | ||
. venv/bin/activate | ||
pip3 install -r requirements.txt |
3 changes: 3 additions & 0 deletions
3
benchmark/training_torch/image_classification/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
tensorboard==2.14.1 | ||
torch==2.0.1 | ||
torchvision==0.15.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import argparse | ||
|
||
import torch | ||
from utils.data import get_test_dataloader | ||
from utils.training import eval_training | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--model-ckpt", | ||
default="trained_models/best.pth", | ||
type=str, | ||
help="Path to model checkpoint for evaluation.", | ||
) | ||
parser.add_argument( | ||
"--batch-size", | ||
type=int, | ||
default=32, | ||
help="Batch size. Default value is 32 according to TF training procedure.", | ||
) | ||
parser.add_argument( | ||
"--data-dir", | ||
default="cifar-10-torch", | ||
type=str, | ||
help="Path to dataset (will be downloaded).", | ||
) | ||
parser.add_argument( | ||
"--workers", default=4, type=int, help="Number of data loading processes." | ||
) | ||
args = parser.parse_args() | ||
|
||
model = torch.load(args.model_ckpt) | ||
|
||
val_loader = get_test_dataloader( | ||
cifar_10_dir=args.data_dir, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
) | ||
loss_function = torch.nn.CrossEntropyLoss() | ||
|
||
accuracy = eval_training( | ||
model=model, | ||
dataloader=val_loader, | ||
loss_function=loss_function, | ||
epoch=0, | ||
log_to_tensorboard=False, | ||
writer=None, | ||
) | ||
|
||
print(f"Model {args.model_ckpt} has accuracy: {accuracy}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import argparse | ||
|
||
import torch | ||
from torch.optim import Adam | ||
from torch.optim.lr_scheduler import LambdaLR | ||
from torch.utils.tensorboard import SummaryWriter | ||
from utils.data import get_test_dataloader | ||
from utils.data import get_training_dataloader | ||
from utils.model import Resnet8v1EEMBC | ||
from utils.training import WarmUpLR | ||
from utils.training import eval_training | ||
from utils.training import train_one_epoch | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--batch-size", | ||
type=int, | ||
default=32, | ||
help="Batch size. Default value is 32 according to TF training procedure.", | ||
) | ||
parser.add_argument( | ||
"--epochs", | ||
type=int, | ||
default=500, | ||
help="Number of epochs. Default value is 500 according to TF training procedure.", | ||
) | ||
parser.add_argument( | ||
"--warmup-epochs", | ||
type=int, | ||
default=0, | ||
help="Number of epochs for LR linear warmup.", | ||
) | ||
parser.add_argument( | ||
"--lr", | ||
default=0.001, | ||
type=float, | ||
help="Initial learning rate. Default value is 1e-3 according to TF training procedure.", | ||
) | ||
parser.add_argument( | ||
"--lr-decay", | ||
default=0.99, | ||
type=float, | ||
help="Initial learning rate. Default value is 1e-3 according to TF training procedure.", | ||
) | ||
parser.add_argument( | ||
"--data-dir", | ||
default="cifar-10-torch", | ||
type=str, | ||
help="Path to dataset (will be downloaded).", | ||
) | ||
parser.add_argument( | ||
"--workers", default=4, type=int, help="Number of data loading processes." | ||
) | ||
parser.add_argument( | ||
"--weight-decay", default=1e-4, type=float, help="Weight decay for optimizer." | ||
) | ||
parser.add_argument("--log-dir", type=str, default="trained_models") | ||
args = parser.parse_args() | ||
|
||
train_loader = get_training_dataloader( | ||
cifar_10_dir=args.data_dir, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
) | ||
val_loader = get_test_dataloader( | ||
cifar_10_dir=args.data_dir, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
) | ||
|
||
model = Resnet8v1EEMBC() | ||
if torch.cuda.is_available(): | ||
model = model.cuda() | ||
optimizer = Adam( | ||
params=model.parameters(), | ||
lr=args.lr, | ||
weight_decay=args.weight_decay, | ||
) | ||
train_scheduler = LambdaLR( | ||
optimizer=optimizer, lr_lambda=lambda epoch: args.lr_decay**epoch | ||
) | ||
warmup_scheduler = None | ||
if args.warmup_epochs: | ||
warmup_scheduler = WarmUpLR(optimizer=optimizer, total_iters=args.warmup_epochs) | ||
|
||
writer = SummaryWriter(log_dir=args.log_dir) | ||
|
||
loss_function = torch.nn.CrossEntropyLoss() | ||
|
||
best_accuracy = 0.0 | ||
for epoch in range(1, args.epochs + 1): | ||
if epoch > args.warmup_epochs: | ||
train_scheduler.step() | ||
|
||
train_one_epoch( | ||
model=model, | ||
train_dataloader=train_loader, | ||
loss_function=loss_function, | ||
optimizer=optimizer, | ||
epoch=epoch, | ||
writer=writer, | ||
warmup_scheduler=warmup_scheduler, | ||
warmup_epochs=args.warmup_epochs, | ||
train_scheduler=train_scheduler, | ||
) | ||
|
||
accuracy = eval_training( | ||
model=model, | ||
dataloader=val_loader, | ||
loss_function=loss_function, | ||
epoch=epoch, | ||
log_to_tensorboard=True, | ||
writer=writer, | ||
) | ||
|
||
if best_accuracy < accuracy: | ||
weights_path = f"{args.log_dir}/best.pth" | ||
print(f"saving weights file to {weights_path}") | ||
torch.save(model, weights_path) | ||
best_accuracy = accuracy | ||
continue | ||
|
||
writer.flush() | ||
|
||
writer.close() |
Empty file.
97 changes: 97 additions & 0 deletions
97
benchmark/training_torch/image_classification/utils/data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import torch | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def get_training_dataloader( | ||
cifar_10_dir: str = "cifar-10-torch", | ||
batch_size: int = 16, | ||
num_workers: int = 2, | ||
shuffle: bool = True, | ||
) -> torch.utils.data.DataLoader: | ||
"""Create DataLoader for training data. | ||
Parameters | ||
---------- | ||
cifar_10_dir: str | ||
Path to CIFAR10 data root in torchvision format. | ||
batch_size: int | ||
Batch size for dataloader. | ||
num_workers: int | ||
Number of subprocesses for data loading. | ||
shuffle: bool | ||
Flag for shuffling training data. | ||
Returns | ||
------- | ||
torch.utils.data.DataLoader | ||
DataLoader for training data. | ||
""" | ||
transform_train = transforms.Compose( | ||
[ | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.RandomRotation(15), | ||
transforms.ToTensor(), | ||
] | ||
) | ||
cifar10_training = torchvision.datasets.CIFAR10( | ||
root=cifar_10_dir, | ||
train=True, | ||
download=True, | ||
transform=transform_train, | ||
) | ||
cifar10_training_loader = DataLoader( | ||
cifar10_training, | ||
shuffle=shuffle, | ||
num_workers=num_workers, | ||
batch_size=batch_size, | ||
) | ||
|
||
return cifar10_training_loader | ||
|
||
|
||
def get_test_dataloader( | ||
cifar_10_dir="cifar-10-torch", | ||
batch_size=16, | ||
num_workers=2, | ||
shuffle=True, | ||
): | ||
"""Create DataLoader for test data. | ||
Parameters | ||
---------- | ||
cifar_10_dir: str | ||
Path to CIFAR10 data root in torchvision format. | ||
batch_size: int | ||
Batch size for dataloader. | ||
num_workers: int | ||
Number of subprocesses for data loading. | ||
shuffle: bool | ||
Flag for shuffling training data. | ||
Returns | ||
------- | ||
torch.utils.data.DataLoader | ||
DataLoader for test data. | ||
""" | ||
transform_test = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
] | ||
) | ||
cifar10_test = torchvision.datasets.CIFAR10( | ||
root=cifar_10_dir, | ||
train=False, | ||
download=True, | ||
transform=transform_test, | ||
) | ||
cifar10_test_loader = DataLoader( | ||
cifar10_test, | ||
shuffle=shuffle, | ||
num_workers=num_workers, | ||
batch_size=batch_size, | ||
) | ||
|
||
return cifar10_test_loader |
75 changes: 75 additions & 0 deletions
75
benchmark/training_torch/image_classification/utils/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
|
||
|
||
class ResNetBlock(nn.Module): | ||
def __init__( | ||
self, | ||
in_channels: int, | ||
out_channels: int, | ||
stride: int = 1, | ||
): | ||
super().__init__() | ||
self.block = nn.Sequential( | ||
nn.Conv2d( | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
bias=True, | ||
stride=stride, | ||
), | ||
nn.BatchNorm2d(num_features=out_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d( | ||
in_channels=out_channels, | ||
out_channels=out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
bias=True, | ||
), | ||
nn.BatchNorm2d(num_features=out_channels), | ||
) | ||
if in_channels == out_channels: | ||
self.residual = nn.Identity() | ||
else: | ||
self.residual = nn.Conv2d( | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=1, | ||
stride=stride, | ||
) | ||
|
||
def forward(self, inputs): | ||
x = self.block(inputs) | ||
y = self.residual(inputs) | ||
return F.relu(x + y) | ||
|
||
|
||
class Resnet8v1EEMBC(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.stem = nn.Sequential( | ||
nn.Conv2d( | ||
in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=True | ||
), | ||
nn.BatchNorm2d(num_features=16), | ||
nn.ReLU(inplace=True), | ||
) | ||
|
||
self.first_stack = ResNetBlock(in_channels=16, out_channels=16, stride=1) | ||
self.second_stack = ResNetBlock(in_channels=16, out_channels=32, stride=2) | ||
self.third_stack = ResNetBlock(in_channels=32, out_channels=64, stride=2) | ||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
self.fc = nn.Linear(in_features=64, out_features=10) | ||
|
||
def forward(self, inputs): | ||
x = self.stem(inputs) | ||
x = self.first_stack(x) | ||
x = self.second_stack(x) | ||
x = self.third_stack(x) | ||
x = self.avgpool(x) | ||
x = torch.flatten(x, 1) | ||
x = self.fc(x) | ||
return x |
Oops, something went wrong.