diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5778b25fa0f6..af098771a521 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1322,14 +1322,10 @@ def _impl(inputs, attr, params, mod): def _fill(): def _impl(inputs, attr, params, mod): - output_shape = attr['_output_shapes'][0] - # Output shape must be defined to avoid errors. If any axis is not, we must - # try to compute its shape. - if output_shape is None or -1 in output_shape: - try: - output_shape = _expr.Constant(_infer_value(inputs[0], params, mod)) - except Exception: - output_shape = inputs[0] + try: + output_shape = _infer_value(inputs[0], params, mod).asnumpy().tolist() + except Exception: + output_shape = inputs[0] return _op.full(inputs[1], output_shape, attr['T'].name) return _impl diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py index fdbb8768597f..771aed06ac10 100644 --- a/python/tvm/relay/frontend/tensorflow_parser.py +++ b/python/tvm/relay/frontend/tensorflow_parser.py @@ -30,6 +30,10 @@ class TFParser(object): model_dir : tensorflow frozen pb file or a directory that contains saved model or checkpoints. + outputs : List of output tensor names (Optional) + Optional output node names. This will be protected for saved model + when we do remove training nodes. + Examples -------- .. code-block:: python @@ -38,11 +42,12 @@ class TFParser(object): graphdef = parser.parse() """ - def __init__(self, model_dir): + def __init__(self, model_dir, outputs=None): from tensorflow.core.framework import graph_pb2 self._tmp_dir = util.tempdir() self._model_dir = model_dir self._graph = graph_pb2.GraphDef() + self._outputs = outputs or [] def _set_graph(self, graph): """Set Graph""" @@ -128,7 +133,8 @@ def _load_saved_model(self): output_graph_def = graph_pb2.GraphDef() with open(output_graph_filename, "rb") as f: output_graph_def.ParseFromString(f.read()) - output_graph_def = graph_util.remove_training_nodes(output_graph_def) + output_graph_def = graph_util.remove_training_nodes(output_graph_def, + protected_nodes=self._outputs) return output_graph_def def _load_ckpt(self):