Skip to content

Commit

Permalink
Hotfix: Circumvent tf-2.12 breaking change on tflite subgraph API to …
Browse files Browse the repository at this point in the history
…unbreak UT

TF-2.12.0 introduced API change that breaks tf2onnx UT tests on the
tflite paths, due to the addition of compulsory subgraph arg to several
function's input signature:
tensorflow/tensorflow@55d84d7

This commit is a temporary hotfix to unbreak related UT failure.
Existing tf2onnx's use cases get tflite Interpreter's tensors from model's
first subgraph only. The hotfix hard-codes subgraph index to `0` to
retain the same behavior while resolves API diff.

Signed-off-by: Yu Cong <congyc@amazon.com>
  • Loading branch information
Yu Cong committed Jul 13, 2023
1 parent 25c977c commit df33fd2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
13 changes: 5 additions & 8 deletions tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops import lookup_ops
import onnx
from packaging.version import Version
from common import get_test_config
from tfjs_runner import run_tfjs
from tf2onnx import constants
Expand All @@ -26,7 +27,7 @@
from tf2onnx import optimizer
from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, tf_placeholder, from_function, freeze_session
from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info
from tf2onnx.tf_utils import compress_graph_def
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
from tf2onnx.graph import ExternalTensorStorage
from tf2onnx.tflite.Model import Model

Expand Down Expand Up @@ -249,14 +250,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):

def tflite_has_supported_types(self, tflite_path):
try:
with open(tflite_path, 'rb') as f:
buf = f.read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
tensor_cnt = model.Subgraphs(0).TensorsLength()
interpreter = tf.lite.Interpreter(tflite_path)
for i in range(tensor_cnt):
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
tensor_details = interpreter.get_tensor_details()
for tensor_detail in tensor_details:
dtype = tensor_detail.get('dtype')
if np.dtype(dtype).kind == 'O':
return False
return True
Expand Down
19 changes: 10 additions & 9 deletions tf2onnx/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from tensorflow.python.framework import tensor_util
import tensorflow as tf
import numpy as np
from packaging.version import Version
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
from tf2onnx.tflite.Model import Model
from tf2onnx.flexbuffers import read_flexbuffer
from tf2onnx.tf_utils import read_tf_node_def_attrs
from tf2onnx.tf_utils import read_tf_node_def_attrs, get_tf_version
from tf2onnx.graph import Graph
from tf2onnx import utils

Expand Down Expand Up @@ -196,14 +197,14 @@ def read_tflite_model(tflite_path):
try:
interpreter = tf.lite.Interpreter(tflite_path)
interpreter.allocate_tensors()
tensor_cnt = model.Subgraphs(0).TensorsLength()
for i in range(tensor_cnt):
name = model.Subgraphs(0).Tensors(i).Name().decode()
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
if "shape_signature" in details:
tensor_shapes[name] = details["shape_signature"].tolist()
elif "shape" in details:
tensor_shapes[name] = details["shape"].tolist()
tensor_details = interpreter.get_tensor_details()

for tensor_detail in tensor_details:
name = tensor_detail.get('name')
if "shape_signature" in tensor_detail:
tensor_shapes[name] = tensor_detail["shape_signature"].tolist()
elif "shape" in tensor_detail:
tensor_shapes[name] = tensor_detail["shape"].tolist()
except Exception as e: # pylint: disable=broad-except
logger.warning("Error loading model into tflite interpreter: %s", e)
tflite_graphs = get_model_subgraphs(model)
Expand Down

0 comments on commit df33fd2

Please sign in to comment.