Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][frontend] TensorFlow saved model support #2586

Merged
merged 3 commits into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
59 changes: 44 additions & 15 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import print_function

import logging
import warnings
# Numpy support
import numpy as np

Expand Down Expand Up @@ -410,7 +411,7 @@ def _impl(inputs, attr, params):
def _decode_image():
def _impl(inputs, attr, params):
        # Image decode wrapper: expecting the user to feed decoded input to the next layer; drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
return inputs[0]
return _impl

Expand Down Expand Up @@ -1178,6 +1179,7 @@ class GraphProto(object):
def __init__(self):
self._nodes = {}
self._params = {}
self._input_shapes = {}
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
Expand Down Expand Up @@ -1229,36 +1231,55 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

for node in graph.node:
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)

# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. Shape in graphdef "
"will be used for operator %s." % node.name)

# Parse the nodes to re-create TF graph using Relay operators.
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build params dict.

input_shapes = {}
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

if node.op == "Placeholder":
self._output_shapes[node.name] = [shape[node.name]]
self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._nodes[node.name] = [_expr.var(node.name,
shape=self._output_shapes[node.name][0],
shape=self._input_shapes[node.name],
dtype=attr['dtype'].name)]

elif node.op == "Const":
Expand All @@ -1274,7 +1295,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

# Pass the node name too in attr
attr["_node_name"] = node.name
Expand All @@ -1301,7 +1322,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

op = self._convert_operator(node.op, inputs, attr, graph)

# Check is op is converted to param
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [_expr.var(node.name,
Expand All @@ -1317,6 +1338,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

self._nodes[node.name] = op

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_type = ir_pass.infer_type(self._nodes[node.name][0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]

if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])

        # Infer shapes if passed explicitly
node_output = self._nodes[node.name]
out_type = ir_pass.infer_type(node_output[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@


class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
```
"""
A Wrapper to handle tensorflow models parsing, TensorFlow is needed

Parameters
----------
model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints.

Examples
--------
.. code-block:: python

        parser = TFParser(model_dir)
graph = parser.parse()
# graph is related graphdef of the model
"""

def __init__(self, model_dir):
Expand Down Expand Up @@ -115,13 +120,16 @@ def _load_ckpt(self):
"""TODO: Load checkpoint model."""
raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
"not supported yet.")
# pylint: disable=unreachable
return 0

def parse(self):
"""Parse tensorflow models: checkpoints, saved models, and single pb
file.
"""
Parse tensorflow models: checkpoints, saved models, and single frozen pb file.

Returns
-------
GraphDef of the passed model
"""

graph = None

if os.path.isdir(self._model_dir):
Expand Down