Skip to content

Commit

Permalink
remove tflite part
Browse files Browse the repository at this point in the history
Signed-off-by: xavier dupré <xavier.dupre@gmail.com>
  • Loading branch information
sdpython committed Sep 18, 2020
1 parent 20a9deb commit 1bcb143
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 89 deletions.
40 changes: 0 additions & 40 deletions examples/end2end_tfhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,43 +73,3 @@
number=10, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
number=10, globals=globals()))

########################################
# Freezes the graph with tensorflow.lite.
converter = tf.lite.TFLiteConverter.from_saved_model("efficientnetb0clas")
tflite_model = converter.convert()
with open("efficientnetb0clas.tflite", "wb") as f:
f.write(tflite_model)

# Builds an interpreter.
interpreter = tf.lite.Interpreter(model_path='efficientnetb0clas.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details", input_details)
print("output_details", output_details)
index = input_details[0]['index']


def tflite_predict(input, interpreter=interpreter, index=index):
res = []
for i in range(input.shape[0]):
interpreter.set_tensor(index, input[i:i + 1])
interpreter.invoke()
res.append(interpreter.get_tensor(output_details[0]['index']))
return np.vstack(res)


print(input[0:1].shape, "----", input_details[0]['shape'])
output_data = tflite_predict(input, interpreter, index)
print(output_data)

########################################
# Measures processing time again.

print('tf:', timeit.timeit('model.predict(input)',
number=10, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
number=10, globals=globals()))
print('tflite:', timeit.timeit('tflite_predict(input)',
number=10, globals=globals()))
49 changes: 0 additions & 49 deletions examples/end2end_tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,52 +68,3 @@
number=100, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
number=100, globals=globals()))

########################################
# Freezes the graph with tensorflow.lite.
converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn")
tflite_model = converter.convert()
with open("simple_rnn.tflite", "wb") as f:
f.write(tflite_model)

# Builds an interpreter.
interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details", input_details)
print("output_details", output_details)
index = input_details[0]['index']


def tflite_predict(input, interpreter=interpreter, index=index):
res = []
for i in range(input.shape[0]):
interpreter.set_tensor(index, input[i:i + 1])
interpreter.invoke()
res.append(interpreter.get_tensor(output_details[0]['index']))
return np.vstack(res)


print(input[0:1].shape, "----", input_details[0]['shape'])
output_data = tflite_predict(input, interpreter, index)
print(output_data)

########################################
# Measures processing time again.

print('tf:', timeit.timeit('model.predict(input)',
number=100, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
number=100, globals=globals()))
print('tflite:', timeit.timeit('tflite_predict(input)',
number=100, globals=globals()))

########################################
# Measures processing time only between onnxruntime and
# tensorflow lite with more loops.

print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
number=10000, globals=globals()))
print('tflite:', timeit.timeit('tflite_predict(input)',
number=10000, globals=globals()))

0 comments on commit 1bcb143

Please sign in to comment.