diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 3a7965f97bbbd..609813d268791 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -36,6 +36,7 @@ def __call__(self, inputs, attrs, *args):
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
         self._ignores.append('_target_layout')
+        self._ignores.append('_input_0d_mismatch')
         # Retain the names
         try:
             attrs['name'] = attrs['_node_name']
@@ -315,8 +316,7 @@ def _impl(inputs, attr, params):
         dim_input = inputs.pop(1)
         axis = params[dim_input.list_output_names()[0]]
         params.pop(dim_input.list_output_names()[0])
-        return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
-                       extras={'axis': axis.asnumpy()[0]})(inputs, attr)
+        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
     return _impl

 def _resize_bilinear():
@@ -379,7 +379,7 @@ def _impl(inputs, attr, params):
 def _pack():
     def _impl(inputs, attr, params):
         axis = int(attr["axis"])
-        inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
         return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
     return _impl

@@ -828,6 +828,13 @@ def _impl(inputs, attr, params):
         )(inputs, attr)
     return _impl

+def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
+    if data in attr['_input_0d_mismatch']:
+        return data if num_newaxis == 1 else \
+            _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)
+
+    return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []

@@ -1092,6 +1099,7 @@ def __init__(self):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
+        self._outputs_are_0d = {}

     def from_tensorflow(self, graph, layout="NHWC", shape=None):
         """Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -1146,6 +1154,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None):
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.
             input_shapes = {}
+            input_0d_mismatch = set()

             attr = self._parse_attr(node.attr)

             #Variable converted to Const will not have only value attr
@@ -1165,6 +1174,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None):
             else:
                 raise NotImplementedError( \
                     "Please freeze the graph with add_shapes=True")
+            self._outputs_are_0d[node.name] = [ \
+                not shape if isinstance(shape, list) else False \
+                for shape in self._output_shapes[node.name]]

             if node.op == "Placeholder":
                 self._nodes[node.name] = _sym.Variable(name=node.name,
@@ -1210,12 +1222,18 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None):
                     if i in self._nodes:
                         tvm_n = self._nodes[i]
-                        outputs = tvm_n.list_output_names()
-                        if len(outputs) > 1:
+                        tvm_n_shape = self._output_shapes[i]
+                        if len(tvm_n.list_output_names()) > 1:
                             tvm_n = tvm_n[num_layer]
+                            tvm_n_shape = [tvm_n_shape[num_layer]]
                         inputs.append(tvm_n)
-                        input_shapes[tvm_n] = self._output_shapes[i]
+                        input_shapes[tvm_n] = tvm_n_shape
+                        #This means the node is 1d in NNVM and 0d in TF.
+                        #See `_expand_dims_0d_aware`.
+                        if self._outputs_are_0d[i][num_layer] and tvm_n_shape[0]:
+                            input_0d_mismatch.add(tvm_n)
                 attr['_input_shapes'] = input_shapes
+                attr['_input_0d_mismatch'] = input_0d_mismatch

                 inputs = self._fix_extranodes(node.op, attr, inputs)

                 op = self._convert_operator(node.op, inputs, attr, graph)
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 95e2558268612..c9406c3c2a62e 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -505,14 +505,21 @@ def test_forward_gather():
 # ------

 def _test_split(ip_shape, num_or_size_splits, axis, dtype):
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+    num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+
     tf.reset_default_graph()
     in_data = tf.placeholder(dtype, ip_shape, name="in_data")
-    num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
     tf.split(in_data, num_or_size_splits, axis=axis)
-    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

     compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])

+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.concat(tf.split(in_data, num_or_size_splits, axis=axis), axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'concat:0')
+
 def test_forward_split():
     '''test split layer'''
     _test_split((6,), 2, 0, 'int32')
@@ -530,13 +537,20 @@ def test_forward_split():
 # ------

 def _test_unstack(ip_shape, axis, dtype):
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+
     tf.reset_default_graph()
     in_data = tf.placeholder(dtype, ip_shape, name="in_data")
     tf.unstack(in_data, axis=axis)
-    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

     compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])

+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
+
 def test_forward_unstack():
     '''test unstack layer'''
     _test_unstack((6,), 0, 'int32')