Skip to content

Commit

Permalink
Fix tf parser (apache#5794)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun authored and Trevor Morris committed Jun 18, 2020
1 parent 0de9ca9 commit 7f23072
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
12 changes: 4 additions & 8 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,14 +1322,10 @@ def _impl(inputs, attr, params, mod):

def _fill():
def _impl(inputs, attr, params, mod):
output_shape = attr['_output_shapes'][0]
# Output shape must be defined to avoid errors. If any axis is not, we must
# try to compute its shape.
if output_shape is None or -1 in output_shape:
try:
output_shape = _expr.Constant(_infer_value(inputs[0], params, mod))
except Exception:
output_shape = inputs[0]
try:
output_shape = _infer_value(inputs[0], params, mod).asnumpy().tolist()
except Exception:
output_shape = inputs[0]

return _op.full(inputs[1], output_shape, attr['T'].name)
return _impl
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/relay/frontend/tensorflow_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class TFParser(object):
model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints.
outputs : List of output tensor names (Optional)
Optional output node names. This will be protected for saved model
when we do remove training nodes.
Examples
--------
.. code-block:: python
Expand All @@ -38,11 +42,12 @@ class TFParser(object):
graphdef = parser.parse()
"""

def __init__(self, model_dir):
def __init__(self, model_dir, outputs=None):
from tensorflow.core.framework import graph_pb2
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()
self._outputs = outputs or []

def _set_graph(self, graph):
"""Set Graph"""
Expand Down Expand Up @@ -128,7 +133,8 @@ def _load_saved_model(self):
output_graph_def = graph_pb2.GraphDef()
with open(output_graph_filename, "rb") as f:
output_graph_def.ParseFromString(f.read())
output_graph_def = graph_util.remove_training_nodes(output_graph_def)
output_graph_def = graph_util.remove_training_nodes(output_graph_def,
protected_nodes=self._outputs)
return output_graph_def

def _load_ckpt(self):
Expand Down

0 comments on commit 7f23072

Please sign in to comment.