From d641aee9fa6c72fd169a668c2d7bae09743ac50d Mon Sep 17 00:00:00 2001
From: Alexey Romanov
Date: Thu, 6 Dec 2018 17:18:35 +0300
Subject: [PATCH] [FRONTEND][TENSORFLOW] Use input shapes directly instead of
 1-element lists

---
 nnvm/python/nnvm/frontend/tensorflow.py | 41 ++++++++++++-------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 47aca3816e6f..a869abac9c4f 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -120,7 +120,7 @@ def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False
 
-        input_shape = attr['_input_shapes'][inputs[0]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
 
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -132,7 +132,7 @@ def _impl(inputs, attr, params):
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
-            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            tmp_shape = attr['_input_shapes'][inputs[0]]
             input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
             attr['data_format'] = "NCHW"
@@ -185,13 +185,13 @@ def _impl(inputs, attr, params):
 
         # NCHW Layout require weights transpose
         if attr['data_format'] == 'NCHW':
-            tmp_shape = attr['_input_shapes'][inputs[1]][0]
+            tmp_shape = attr['_input_shapes'][inputs[1]]
             tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
             inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
-            attr['_input_shapes'][inputs[1]] = [tmp_shape]
+            attr['_input_shapes'][inputs[1]] = tmp_shape
 
-        input_shape = attr['_input_shapes'][inputs[0]][0]
-        weights_shape = attr['_input_shapes'][inputs[1]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
+        weights_shape = attr['_input_shapes'][inputs[1]]
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
@@ -484,7 +484,7 @@ def _impl(inputs, attr, params):
 
 def _shape():
     def _impl(inputs, attr, params):
-        return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
     return _impl
 
 def _fill():
@@ -565,7 +565,7 @@ def _impl(inputs, attr, params):
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)
 
         def _transform_mask(stride_dim, ellipsis_mask):
@@ -596,7 +596,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
@@ -606,19 +606,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                             if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[0][final_index]
+                            else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                             if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
@@ -684,8 +684,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4
 
         in_data = _sym.reshape(in_data,
@@ -741,11 +741,10 @@ def _impl(inputs, attr, params):
 
 def _rank():
     def _impl(inputs, attr, params):
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        assert len(inputs) == 1
+        input_shape = attr['_input_shapes'][inputs[0]]
 
         name = attr["_node_name"]
-        params[name] = tvm.nd.array([len(input_shapes[0])])
+        params[name] = tvm.nd.array([len(input_shape)])
         return _sym.Variable(name=name, shape=params[name].shape)
     return _impl
 
@@ -829,7 +828,7 @@ def _unpack():
     def _impl(inputs, attr, params):
         input_node = inputs[0]
         axis = attr['axis']
-        input_shape = attr['_input_shapes'][input_node][0]
+        input_shape = attr['_input_shapes'][input_node]
         axis_length = input_shape[axis]
         if axis_length < 0:
             raise TypeError("Unstack with unknown axis length")
@@ -1018,8 +1017,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
         """LSTM cell warapper to prepare the inputs"""
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size = input_shape[0][0]
-        num_hidden = weight_shape[0][1] // 4
+        batch_size = input_shape[0]
+        num_hidden = weight_shape[1] // 4
 
         if layer == 0:
             #Create initial states placeholder in case of first layer
@@ -1240,7 +1239,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                         tensor_slot = 0
                         input_shape = self._output_shapes[node_name][0]
                     inputs.append(in_sym)
-                    input_shapes[in_sym] = [input_shape]
+                    input_shapes[in_sym] = input_shape
                 # This means the node is 1d in NNVM and 0d in TF.
                 # See `_expand_dims_0d_aware`.
                 if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
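
A minimal sketch of the bookkeeping convention this patch changes, for
reviewers. The 'data' key and the shape below are illustrative stand-ins
for a real NNVM input symbol and its inferred shape; only the removal of
the trailing [0] reflects the patch itself:

    # Before: attr['_input_shapes'] wrapped every shape in a 1-element
    # list, so every consumer had to unwrap it with an extra [0].
    attr_before = {'_input_shapes': {'data': [[1, 224, 224, 3]]}}
    shape = attr_before['_input_shapes']['data'][0]

    # After: the shape is stored directly, so the [0] disappears at
    # every call site touched above.
    attr_after = {'_input_shapes': {'data': [1, 224, 224, 3]}}
    assert attr_after['_input_shapes']['data'] == shape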