
Commit

Add onnx export notebook
jrrodri committed Nov 25, 2023
1 parent 7942d3f commit a116f17
Showing 7 changed files with 297 additions and 304 deletions.
26 changes: 20 additions & 6 deletions abraia/torch.py
@@ -2,7 +2,7 @@
 from .multiple import Multiple, tempdir
 
 import torch
-import torch.backends.cudnn as cudnn
+import torchvision
 from torchvision import models, transforms
 
 import os
@@ -16,7 +16,7 @@
 multiple = Multiple()
 
 
-cudnn.benchmark = True
+torch.backends.cudnn.benchmark = True
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


@@ -66,13 +66,15 @@ def create_model(class_names, pretrained=True):
         param.requires_grad = False
     num_ftrs = model.fc.in_features
     model.fc = torch.nn.Linear(num_ftrs, len(class_names))
+    model.to(device)
+    print('device', device)
     return model
 
 
 def save_model(path, model, device='cpu'):
-    model.to(device)
     src = os.path.join(tempdir, path)
     os.makedirs(os.path.dirname(src), exist_ok=True)
+    model.to(device)
     torch.save(model.state_dict(), src)
     multiple.upload_file(src, path)

@@ -108,7 +110,13 @@ def load_classes(path):
 # License: BSD
 # Author: Sasank Chilamkurthy
 
-def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
+def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=None, num_epochs=25):
+    criterion = criterion or torch.nn.CrossEntropyLoss()
+    # By default, only the parameters of the final layer are optimized
+    optimizer = optimizer or torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
+    # Decay LR by a factor of 0.1 every 7 epochs
+    scheduler = scheduler or torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
+
     since = time.time()
 
     best_model_wts = copy.deepcopy(model.state_dict())
@@ -188,8 +196,14 @@ def imshow(inp, title=None):
     plt.pause(0.001)  # pause a bit so that plots are updated
 
 
-def visualize_model(model, dataloaders, num_images=6):
-    dataloader = dataloaders['val']
+def visualize_data(dataloader):
+    class_names = dataloader.dataset.classes
+    inputs, classes = next(iter(dataloader))
+    out = torchvision.utils.make_grid(inputs)  # Make a grid from batch
+    imshow(out, title=[class_names[x] for x in classes])
+
+
+def visualize_model(model, dataloader, num_images=6):
     class_names = dataloader.dataset.classes
     was_training = model.training
     model.eval()
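Taken together, the abraia/torch.py changes allow a fine-tuning run with sensible defaults. A minimal sketch of the intended flow, assuming dataloaders is a dict with 'train' and 'val' DataLoaders built beforehand, and that train_model returns the best model as in the PyTorch transfer-learning tutorial this file adapts:

from abraia.torch import create_model, train_model, save_model, visualize_data

# dataloaders = {'train': ..., 'val': ...}  # assumed built elsewhere
class_names = ['ants', 'bees', 'cats', 'dogs']
visualize_data(dataloaders['train'])   # preview one labeled batch as an image grid
model = create_model(class_names)      # frozen backbone, new final layer, moved to device
model = train_model(model, dataloaders, num_epochs=10)  # defaults: CrossEntropyLoss, SGD on model.fc, StepLR
save_model('hymenoptera_data/model_ft.pt', model)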
Binary file modified images/screenshot.jpg
86 changes: 43 additions & 43 deletions notebooks/torch_inference.ipynb
@@ -1,22 +1,13 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    }
-  },
   "cells": [
     {
       "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "vlu1ATx-bOzg"
+      },
+      "outputs": [],
       "source": [
         "%%capture\n",
         "!python -m pip install abraia\n",
@@ -28,13 +19,7 @@
         "    abraia_key = '' #@param {type: \"string\"}\n",
         "    %env ABRAIA_ID=$abraia_id\n",
         "    %env ABRAIA_KEY=$abraia_key"
-      ],
-      "metadata": {
-        "cellView": "form",
-        "id": "vlu1ATx-bOzg"
-      },
-      "execution_count": null,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
@@ -44,50 +29,65 @@
       },
       "outputs": [],
       "source": [
-        "import torch\n",
         "from abraia.torch import load_classes, load_model\n",
         "\n",
         "dataset = 'hymenoptera_data'\n",
         "\n",
         "class_names = load_classes(os.path.join(dataset, 'model_ft.txt'))\n",
-        "model = load_model(os.path.join(dataset, 'model_ft.pt'))\n",
-        "model.eval()"
+        "model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)\n",
+        "# model.eval()"
       ]
     },
     {
       "cell_type": "code",
-      "source": [
-        "from abraia.torch import read_image, transform\n",
-        "\n",
-        "img = read_image('dog.jpg')\n",
-        "\n",
-        "batch = transform(img).unsqueeze(0)\n",
-        "prediction = model(batch).squeeze(0).softmax(0)\n",
-        "class_id = prediction.argmax().item()\n",
-        "score = prediction[class_id].item()\n",
-        "print(class_names[class_id], score)\n",
-        "\n",
-        "_, indices = torch.sort(prediction, descending=True)\n",
-        "print([(class_names[idx], prediction[idx].item()) for idx in indices])"
-      ],
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
         "id": "SsGmZXtlC7Jg",
         "outputId": "6291cc72-dc5c-4c8a-c10b-1372e03140a7"
       },
-      "execution_count": null,
       "outputs": [
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "dogs 0.9997120499610901\n",
             "[('dogs', 0.9997120499610901), ('ants', 0.00015750189777463675), ('bees', 7.2314782300964e-05), ('cats', 5.8050783991348e-05)]\n"
           ]
         }
       ],
+      "source": [
+        "import torch\n",
+        "from abraia.torch import read_image, transform\n",
+        "\n",
+        "img = read_image(os.path.join(dataset, 'dog.jpg'))\n",
+        "\n",
+        "batch = transform(img).unsqueeze(0)\n",
+        "prediction = model(batch).squeeze(0).softmax(0)\n",
+        "class_id = prediction.argmax().item()\n",
+        "score = prediction[class_id].item()\n",
+        "print(class_names[class_id], score)\n",
+        "\n",
+        "_, indices = torch.sort(prediction, descending=True)\n",
+        "print([(class_names[idx], prediction[idx].item()) for idx in indices])"
+      ]
     }
-  ]
-}
+  ],
+  "metadata": {
+    "colab": {
+      "collapsed_sections": [],
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
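One caveat in the cells above: model.eval() is left commented out, so batch-norm layers still use per-batch statistics during inference. A short sketch of the usual pattern, reusing model, batch, and class_names from the notebook; the torch.no_grad context is standard PyTorch rather than part of this commit, and torch.topk is shown as an alternative to torch.sort for the ranking:

model.eval()  # switch batch-norm/dropout layers to inference behavior
with torch.no_grad():
    prediction = model(batch).squeeze(0).softmax(0)
scores, indices = torch.topk(prediction, k=len(class_names))  # ranked scores and indices together
print([(class_names[i], s) for i, s in zip(indices.tolist(), scores.tolist())])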
155 changes: 155 additions & 0 deletions notebooks/torch_onnx.ipynb
@@ -0,0 +1,155 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {
+        "cellView": "form",
+        "id": "-cp253OYk0zk"
+      },
+      "outputs": [],
+      "source": [
+        "%%capture\n",
+        "!python -m pip install abraia\n",
+        "!python -m pip install onnx onnxruntime\n",
+        "\n",
+        "import os\n",
+        "if not os.getenv('ABRAIA_ID') and not os.getenv('ABRAIA_KEY'):\n",
+        "    #@markdown <a href=\"https://abraia.me/console/gallery\" target=\"_blank\">Upload and manage your images</a>\n",
+        "    abraia_id = '' #@param {type: \"string\"}\n",
+        "    abraia_key = '' #@param {type: \"string\"}\n",
+        "    %env ABRAIA_ID=$abraia_id\n",
+        "    %env ABRAIA_KEY=$abraia_key\n",
+        "\n",
+        "from abraia import Abraia\n",
+        "\n",
+        "multiple = Abraia()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import torch\n",
+        "import torchvision\n",
+        "\n",
+        "dummy_input = torch.randn(1, 3, 224, 224)\n",
+        "model = torchvision.models.alexnet(pretrained=True)\n",
+        "model.eval()\n",
+        "\n",
+        "input_names = [\"input1\"]\n",
+        "output_names = [\"output1\"]\n",
+        "\n",
+        "torch.onnx.export(\n",
+        "    model,\n",
+        "    dummy_input,\n",
+        "    \"assets/model.onnx\",\n",
+        "    verbose=True,\n",
+        "    input_names=input_names,\n",
+        "    output_names=output_names,\n",
+        ")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {
+        "id": "EmW59CFnnBHS"
+      },
+      "outputs": [],
+      "source": [
+        "import torch\n",
+        "from abraia.torch import load_classes, load_model, read_image, transform\n",
+        "\n",
+        "dataset = 'hymenoptera_data'\n",
+        "\n",
+        "class_names = load_classes(os.path.join(dataset, 'model_ft.txt'))\n",
+        "model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)\n",
+        "# model.eval()\n",
+        "\n",
+        "img = read_image(os.path.join(dataset, 'dog.jpg'))\n",
+        "input_batch = transform(img).unsqueeze(0)\n",
+        "# TODO: Define export_onnx\n",
+        "torch.onnx.export(model, input_batch, 'model_ft.onnx', export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])\n",
+        "# onnx_model = onnx.load('model_ft.onnx')\n",
+        "# onnx.checker.check_model(onnx_model)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 6,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "YkUzYb9dmcXO",
+        "outputId": "a99d48c9-55c2-42a4-aafc-750ce117719a"
+      },
+      "outputs": [
+        {
+          "data": {
+            "text/plain": [
+              "{'dogs': 3.2437400817871094,\n",
+              " 'bees': -1.3845250606536865,\n",
+              " 'ants': -1.5448365211486816,\n",
+              " 'cats': -2.4362499713897705}"
+            ]
+          },
+          "execution_count": 6,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "import time\n",
+        "import numpy as np\n",
+        "import onnxruntime as ort\n",
+        "from PIL import Image\n",
+        "\n",
+        "def preprocess(img):\n",
+        "    '''Take a loaded image and return a processed tensor.'''\n",
+        "    img = np.array(img.resize((256, 256))).astype(np.float32)\n",
+        "    # center crop\n",
+        "    rm_pad = (256-224)//2\n",
+        "    img = img[rm_pad:-rm_pad, rm_pad:-rm_pad]\n",
+        "    # normalize by mean + std\n",
+        "    img = (img / 255. - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])\n",
+        "    img = img.transpose((2, 0, 1))\n",
+        "    img = np.expand_dims(img, axis=0)\n",
+        "    return img\n",
+        "\n",
+        "def predict(path):\n",
+        "    img = Image.open(path)\n",
+        "    img_batch = preprocess(img)\n",
+        "    outputs = ort_session.run(None, {\"input\": img_batch.astype(np.float32)})\n",
+        "    indices = np.argsort(-outputs[0].flatten())\n",
+        "    results = {}\n",
+        "    for i in indices[0:5]:\n",
+        "        results[labels[i]] = float(outputs[0][0][i])\n",
+        "    return results\n",
+        "\n",
+        "ort_session = ort.InferenceSession(\"model_ft.onnx\", providers=['CPUExecutionProvider'])\n",
+        "\n",
+        "labels = class_names\n",
+        "image_path = \"/tmp/hymenoptera_data/dog.jpg\"\n",
+        "predict(image_path)"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
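The export cell leaves the onnx.checker call commented out and never compares the exported graph against the source network. A hedged verification sketch, reusing model, input_batch, and ort_session from the cells above; the rtol/atol values are conventional tolerances (as in the PyTorch ONNX export tutorial), not from this commit:

import onnx
import torch
import numpy as np

onnx_model = onnx.load('model_ft.onnx')
onnx.checker.check_model(onnx_model)  # structural validation of the exported graph

with torch.no_grad():
    torch_out = model(input_batch).numpy()  # reference output from the PyTorch model
ort_out = ort_session.run(None, {'input': input_batch.numpy()})[0]  # same batch through ONNX Runtime
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)
print('ONNX Runtime output matches PyTorch within tolerance')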

