Showing 51 changed files with 1,313 additions and 238 deletions.
@@ -0,0 +1,47 @@
"""
The following code compares the speed of tensorflow against onnxruntime
with a model downloaded from TensorFlow Hub.
"""
import time
import numpy
from tqdm import tqdm
import tensorflow_hub as hub
import onnxruntime as ort


def generate_random_images(shape=(100, 100), n=10):
    imgs = []
    for i in range(n):
        sh = (1,) + shape + (3,)
        img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
        img = img.astype(numpy.float32)
        imgs.append(img)
    return imgs


def measure_time(fct, imgs):
    results = []
    times = []
    for img in tqdm(imgs):
        begin = time.perf_counter()
        result = fct(img)
        end = time.perf_counter()
        results.append(result)
        times.append(end - begin)
    return results, times


imgs = generate_random_images()

# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
sess = ort.InferenceSession('esrgan-tf2.onnx')
fct_ort = lambda img: sess.run(None, {'input_0:0': img})
results_ort, duration_ort = measure_time(fct_ort, imgs)
print(len(imgs), duration_ort)

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
results_tf, duration_tf = measure_time(model, imgs)
print(len(imgs), duration_tf)

print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))
File renamed without changes.
@@ -0,0 +1,75 @@
"""
This example retrieves a model from TensorFlow Hub.
It is converted into ONNX. Predictions are compared to
the predictions from tensorflow to check there are no
discrepancies. Inference time is also compared between
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
"""
from onnxruntime import InferenceSession
import os
import sys
import subprocess
import timeit
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input
try:
    import tensorflow_hub as tfhub
except ImportError:
    # no tensorflow_hub
    print("tensorflow_hub not installed.")
    sys.exit(0)

########################################
# Downloads the model.
hub_layer = tfhub.KerasLayer(
    "https://tfhub.dev/google/efficientnet/b0/classification/1")
model = keras.Sequential()
model.add(Input(shape=(224, 224, 3), dtype=tf.float32))
model.add(hub_layer)
model.summary()

########################################
# Saves the model.
if not os.path.exists("efficientnetb0clas"):
    os.mkdir("efficientnetb0clas")
tf.keras.models.save_model(model, "efficientnetb0clas")

input_names = [n.name for n in model.inputs]
output_names = [n.name for n in model.outputs]
print('inputs:', input_names)
print('outputs:', output_names)

########################################
# Tests the model.
input = np.random.randn(2, 224, 224, 3).astype(np.float32)
expected = model.predict(input)
print(expected)

########################################
# Runs the command line.
proc = subprocess.run(
    'python -m tf2onnx.convert --saved-model efficientnetb0clas '
    '--output efficientnetb0clas.onnx --opset 12'.split(),
    capture_output=True)
print(proc.returncode)
print(proc.stdout.decode('ascii'))
print(proc.stderr.decode('ascii'))

########################################
# Runs onnxruntime.
session = InferenceSession("efficientnetb0clas.onnx")
got = session.run(None, {'input_1:0': input})
print(got[0])

########################################
# Measures the differences.
print(np.abs(got[0] - expected).max())

########################################
# Measures processing time.
print('tf:', timeit.timeit('model.predict(input)',
                           number=10, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
                            number=10, globals=globals()))
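
The script above drives the conversion through a `subprocess` call to the tf2onnx command line. The same conversion can also be done in-process; the sketch below is not part of the commit, reuses `model`, `input` and `expected` from the script above, and assumes a tf2onnx version that exposes `tf2onnx.convert.from_keras` (the output path and the "input" tensor name are illustrative choices):

# Sketch (illustrative): in-process Keras-to-ONNX conversion with tf2onnx.
import numpy as np
import tensorflow as tf
import tf2onnx
from onnxruntime import InferenceSession

spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=spec, opset=12,
    output_path="efficientnetb0clas_inprocess.onnx")

sess = InferenceSession("efficientnetb0clas_inprocess.onnx")
# The input name follows the TensorSpec above, not the 'input_1:0' name
# produced by the command-line conversion.
got_inprocess = sess.run(None, {"input": input})
print(np.abs(got_inprocess[0] - expected).max())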
@@ -0,0 +1,70 @@
"""
This example builds a simple model without training.
It is converted into ONNX. Predictions are compared to
the predictions from tensorflow to check there are no
discrepancies. Inference time is also compared between
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
"""
from onnxruntime import InferenceSession
import os
import subprocess
import timeit
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Input

########################################
# Creates the model.
model = keras.Sequential()
model.add(Input((4, 4)))
model.add(layers.SimpleRNN(8))
model.add(layers.Dense(2))
model.summary()
input_names = [n.name for n in model.inputs]
output_names = [n.name for n in model.outputs]
print('inputs:', input_names)
print('outputs:', output_names)

########################################
# Training
# ....
# Skipped.

########################################
# Tests the model.
input = np.random.randn(2, 4, 4).astype(np.float32)
expected = model.predict(input)
print(expected)

########################################
# Saves the model.
if not os.path.exists("simple_rnn"):
    os.mkdir("simple_rnn")
tf.keras.models.save_model(model, "simple_rnn")

########################################
# Runs the command line.
proc = subprocess.run('python -m tf2onnx.convert --saved-model simple_rnn '
                      '--output simple_rnn.onnx --opset 12'.split(),
                      capture_output=True)
print(proc.returncode)
print(proc.stdout.decode('ascii'))
print(proc.stderr.decode('ascii'))

########################################
# Runs onnxruntime.
session = InferenceSession("simple_rnn.onnx")
got = session.run(None, {'input_1:0': input})
print(got[0])

########################################
# Measures the differences.
print(np.abs(got[0] - expected).max())

########################################
# Measures processing time.
print('tf:', timeit.timeit('model.predict(input)',
                           number=100, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
                            number=100, globals=globals()))
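
The docstring mentions a *tensorflow.lite* comparison that this file does not implement. Below is a sketch, not part of the commit, of how such a timing measurement might look; it reuses `input` from the script above, assumes the `simple_rnn` SavedModel converts with the standard TFLite converter, and enables the TF select-ops fallback in case the `SimpleRNN` layer has no built-in TFLite kernel:

# Sketch (illustrative): convert the SavedModel to TFLite and time inference.
import timeit
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn")
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # regular TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,     # fallback for ops without TFLite kernels
]
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
in_det = interpreter.get_input_details()[0]
out_det = interpreter.get_output_details()[0]
# The converted model defaults to batch size 1; resize to match `input`.
interpreter.resize_tensor_input(in_det['index'], input.shape)
interpreter.allocate_tensors()


def run_tflite(img):
    interpreter.set_tensor(in_det['index'], img)
    interpreter.invoke()
    return interpreter.get_tensor(out_det['index'])


print('tflite:', timeit.timeit('run_tflite(input)',
                               number=100, globals=globals()))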