Add export_onnx to torch
jrrodri committed Dec 4, 2023
1 parent ed40176 commit 0e13f0d
Showing 5 changed files with 311 additions and 164 deletions.
13 changes: 12 additions & 1 deletion abraia/torch.py
@@ -1,6 +1,7 @@
from __future__ import print_function, division
from .multiple import Multiple, tempdir

import onnx
import torch
import torchvision
from torchvision import models, transforms
@@ -67,7 +68,6 @@ def create_model(class_names, pretrained=True):
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, len(class_names))
    model.to(device)
    print('device', device)
    return model


@@ -86,6 +86,17 @@ def load_model(path, class_names):
return model


def export_onnx(path, model, device='cpu'):
    model.to(device)
    dummy_input = torch.randn(1, 3, 224, 224, device=device)
    src = os.path.join(tempdir, path)
    os.makedirs(os.path.dirname(src), exist_ok=True)
    torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])
    onnx_model = onnx.load(src)
    onnx.checker.check_model(onnx_model)
    multiple.upload_file(src, path)


def save_classes(path, class_names):
    txt = '\n'.join(class_names)
    multiple.save_file(path, txt)
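For reference, a minimal usage sketch of the new export_onnx helper, mirroring the call made in notebooks/torch_onnx.ipynb; the hymenoptera_data paths and file names are illustrative and assume a model already fine-tuned and saved with the training notebook.

import os
from abraia.torch import load_classes, load_model, export_onnx

dataset = 'hymenoptera_data'

# Load the class names and the fine-tuned model saved during training
class_names = load_classes(os.path.join(dataset, 'model_ft.txt'))
model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)

# Export the model to ONNX (validated with onnx.checker) and upload it to the same storage path
export_onnx(os.path.join(dataset, 'model_ft.onnx'), model)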
156 changes: 97 additions & 59 deletions notebooks/torch_onnx.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 27,
"metadata": {
"cellView": "form",
"id": "-cp253OYk0zk"
@@ -28,113 +28,151 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 28,
"metadata": {
"id": "dCkPRZpkJ3Rv"
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"dummy_input = torch.randn(1, 3, 224, 224)\n",
"model = torchvision.models.alexnet(pretrained=True)\n",
"model.eval()\n",
"\n",
"input_names = [\"input1\"]\n",
"output_names = [\"output1\"]\n",
"\n",
"torch.onnx.export(\n",
" model,\n",
" dummy_input,\n",
" \"assets/model.onnx\",\n",
" verbose=True,\n",
" input_names=input_names,\n",
" output_names=output_names,\n",
")"
"# import torch\n",
"# import torchvision\n",
"\n",
"# dummy_input = torch.randn(1, 3, 224, 224)\n",
"# model = torchvision.models.mobilenet_v2(pretrained=True)\n",
"# model.eval()\n",
"\n",
"# torch.onnx.export(model, dummy_input, \"model.onnx\", verbose=True, input_names=['input'], output_names=['output'])\n",
"\n",
"# multiple.upload_file(\"model.onnx\", \"camera/model.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 29,
"metadata": {
"id": "EmW59CFnnBHS"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EmW59CFnnBHS",
"outputId": "c517d24c-a570-4db1-839a-3b26c7d97a03"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"device cpu\n"
]
}
],
"source": [
"import onnx\n",
"import torch\n",
"from abraia.torch import load_classes, load_model, read_image, transform\n",
"from abraia.torch import load_classes, load_model\n",
"from abraia.multiple import tempdir\n",
"\n",
"dataset = 'hymenoptera_data'\n",
"\n",
"class_names = load_classes(os.path.join(dataset, 'model_ft.txt'))\n",
"model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)\n",
"# model.eval()\n",
"\n",
"img = read_image(os.path.join(dataset, 'dog.jpg'))\n",
"input_batch = transform(img).unsqueeze(0)\n",
"# TODO: Define export_onnx\n",
"torch.onnx.export(model, input_batch, 'model_ft.onnx', export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])\n",
"# onnx_model = onnx.load('model_ft.onnx')\n",
"# onnx.checker.check_model(onnx_model)"
"\n",
"def export_onnx(path, model):\n",
" dummy_input = torch.randn(1, 3, 224, 224)\n",
" src = os.path.join(tempdir, path)\n",
" os.makedirs(os.path.dirname(src), exist_ok=True)\n",
" torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])\n",
" onnx_model = onnx.load(src)\n",
" onnx.checker.check_model(onnx_model)\n",
" return multiple.upload_file(src, path)\n",
"\n",
"\n",
"model_path = export_onnx(os.path.join(dataset, 'model_ft.onnx'), model)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YkUzYb9dmcXO",
"outputId": "a99d48c9-55c2-42a4-aafc-750ce117719a"
"outputId": "c3be5bfc-92e3-4087-c1a0-0ce042bde614"
},
"outputs": [
{
"data": {
"text/plain": [
"{'dogs': 3.2437400817871094,\n",
" 'bees': -1.3845250606536865,\n",
" 'ants': -1.5448365211486816,\n",
" 'cats': -2.4362499713897705}"
"{'dogs': 2.679403066635132,\n",
" 'bees': -0.37143513560295105,\n",
" 'ants': -1.542833685874939,\n",
" 'cats': -1.6038011312484741}"
]
},
"execution_count": 6,
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import time\n",
"import numpy as np\n",
"import onnxruntime as ort\n",
"from PIL import Image\n",
"\n",
"\n",
"ort_session = ort.InferenceSession(f\"/tmp/{model_path}\", providers=['CPUExecutionProvider'])\n",
"\n",
"\n",
"def resize(img, size):\n",
" width = size if img.height > img.width else round(size * img.width / img.height)\n",
" height = round(size * img.height / img.width) if img.height > img.width else size\n",
" return img.resize((width, height))\n",
"\n",
"\n",
"def crop(img, size):\n",
" left, top = (img.width - size) // 2, (img.height - size) // 2\n",
" right, bottom = left + size, top + size\n",
" return img.crop((left, top, right, bottom))\n",
"\n",
"\n",
"def normalize(img, mean, std):\n",
" img = (np.array(img) / 255. - np.array(mean)) / np.array(std)\n",
" return img.astype(np.float32)\n",
"\n",
"\n",
"def preprocess(img):\n",
" '''The function takes loaded image and returns processed tensor.'''\n",
" img = np.array(img.resize((256, 256))).astype(np.float32)\n",
" #center crop\n",
" rm_pad = (256-224)//2\n",
" img = img[rm_pad:-rm_pad,rm_pad:-rm_pad]\n",
" #normalize by mean + std\n",
" img = (img / 255. - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])\n",
" img = img.transpose((2, 0, 1))\n",
" img = np.expand_dims(img, axis=0)\n",
" return img\n",
"\n",
"def predict(path):\n",
" img = Image.open(path)\n",
" img_batch = preprocess(img)\n",
" outputs = ort_session.run(None, {\"input\": img_batch.astype(np.float32)})\n",
" img = resize(img, 256)\n",
" img = crop(img, 224)\n",
" img = normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
" return np.expand_dims(img.transpose((2, 0, 1)), axis=0)\n",
"\n",
"\n",
"def predict(src):\n",
" img = Image.open(src)\n",
" input = preprocess(img)\n",
" outputs = ort_session.run(None, {\"input\": input})\n",
" a = np.argsort(-outputs[0].flatten())\n",
" results = {}\n",
" for i in a[0:5]:\n",
" results[labels[i]]=float(outputs[0][0][i])\n",
" results[class_names[i]]=float(outputs[0][0][i])\n",
" return results\n",
"\n",
"ort_session = ort.InferenceSession(\"model_ft.onnx\", providers=['CPUExecutionProvider'])\n",
"\n",
"labels = class_names\n",
"image_path = \"/tmp/hymenoptera_data/dog.jpg\"\n",
"predict(image_path)"
"filename = 'dog.jpg'\n",
"multiple.download_file(os.path.join(dataset, filename), filename)\n",
"predict(filename)"
]
}
],
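The numpy preprocessing in the inference cell (resize the short side to 256, center-crop to 224, normalize with the ImageNet mean and std, transpose to CHW) is intended to match the torchvision transform used at training time. A sketch of the assumed equivalent transform, for comparison only:

from torchvision import transforms

# Assumed training-time transform that the numpy resize/crop/normalize pipeline reproduces
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])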
300 changes: 199 additions & 101 deletions notebooks/torch_training.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions scripts/abraia
@@ -42,7 +42,7 @@ def input_files(src):


@click.group('abraia')
@click.version_option('0.12.3')
@click.version_option('0.12.4')
def cli():
    """Abraia CLI tool"""
    pass
@@ -64,7 +64,7 @@ def configure():
@cli.command()
def info():
"""Show user account information"""
click.echo('abraia, version 0.12.3\n')
click.echo('abraia, version 0.12.4\n')
click.echo('Go to [' + click.style('https://abraia.me/console/', fg='green') + '] to see your account information\n')


2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@

setup(
    name='abraia',
    version='0.12.3',
    version='0.12.4',
    description='Abraia Multiple SDK',
    long_description=long_description,
    long_description_content_type='text/markdown',
