[relay][op] add expand op (from ONNX) to relay frontend (apache#4483)
* Add Expand to onnx.py

* add test function for expand

* Fix a onnx frontend test

* Add tests for the value itself instead of shape only on test_expand

* Cleaned up some unnecessary modifications.
Takato Yamada authored and zhiics committed Jan 11, 2020
1 parent e184ef4 commit 403174f
Showing 2 changed files with 88 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -1080,6 +1080,52 @@ class Or(Elemwise):
    def _impl_v7(cls, inputs, attr, params):
        return _op.logical_or(inputs[0], inputs[1])

class Expand(OnnxOpConverter):
    """ Operator converter for Expand.
    """
    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        in_shape = np.array(infer_shape(inputs[0])).astype('int32')
        if get_name(inputs[1]) in params:
            shape = params[inputs[1].name_hint].asnumpy().astype('int32')
        else:
            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')

        # Currently 'op.broadcast_to' expects the rank of the given 'shape'
        # (the 2nd input) to be at least that of the given 'input' (the 1st input).
        # However, ONNX Expand supports multidirectional broadcasting, which allows
        # the above pattern and also allows an extent of 'shape' to be smaller than
        # the corresponding extent of 'input'; in that case, the extent of 'shape'
        # must be 1.
        # https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
        # In those cases we cannot directly apply 'op.broadcast_to' in place of
        # 'Expand', so here we solve the problem by expanding the given 'shape' itself.
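        # For example, an 'input' of shape (3, 1) with a requested 'shape' of
        # (2, 1, 6) is rewritten here to (2, 3, 6), which 'op.broadcast_to'
        # can then handle (illustrative values taken from the tests below).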
        def expand_shape(in_shape, shape):
            """ A function that expands the shape when its rank is lower than
            that of the given input. Also replaces any extent of the shape that
            is 1 with the corresponding extent of the input.
            """

            # Flip the shapes because the logic is simpler to write when the
            # innermost dimension is located at index 0.
            in_shape = np.flip(in_shape, axis=0)
            shape = np.flip(shape, axis=0)

            if in_shape.size < shape.size:
                for i in range(shape.size):
                    if i < in_shape.size and in_shape[i] > shape[i]:
                        shape[i] = in_shape[i]
            else:
                for i in range(in_shape.size):
                    if i >= shape.size:
                        # np.append returns a new array rather than modifying
                        # 'shape' in place, so the result must be assigned back.
                        shape = np.append(shape, in_shape[i])
                    elif shape[i] == 1:
                        shape[i] = in_shape[i]

            new_shape = np.flip(shape, axis=0)
            return new_shape

        shape = expand_shape(in_shape, shape)
        return _op.broadcast_to(inputs[0], shape=tuple(shape))

# compatible operators that do NOT require any conversion.
_identity_list = []
@@ -1187,6 +1233,7 @@ def _get_convert_map(opset):
        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Expand': Expand.get_converter(opset),
        'Concat': Concat.get_converter(opset),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
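As a sanity check on the shape-expansion logic above, here is a minimal standalone sketch using only numpy; 'expand_shape_demo' is an illustrative name introduced here, and the sample shapes are taken from the tests below.

import numpy as np

def expand_shape_demo(in_shape, shape):
    """Standalone mirror of the converter's expand_shape, for illustration."""
    # Flip so the innermost dimension sits at index 0, as in the converter.
    in_shape = np.flip(np.array(in_shape, dtype='int32'), axis=0)
    shape = np.flip(np.array(shape, dtype='int32'), axis=0)
    if in_shape.size < shape.size:
        # 'shape' has higher rank: keep it, but never shrink an input extent.
        for i in range(shape.size):
            if i < in_shape.size and in_shape[i] > shape[i]:
                shape[i] = in_shape[i]
    else:
        # 'shape' has equal or lower rank: pad it with the input's extents
        # and replace any extent of 1 with the input's extent.
        for i in range(in_shape.size):
            if i >= shape.size:
                shape = np.append(shape, in_shape[i])
            elif shape[i] == 1:
                shape[i] = in_shape[i]
    return tuple(int(d) for d in np.flip(shape, axis=0))

assert expand_shape_demo((3, 1), (3, 4)) == (3, 4)
assert expand_shape_demo((3, 1), (2, 1, 6)) == (2, 3, 6)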
41 changes: 41 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -142,6 +142,46 @@ def test_reshape():
    tvm.testing.assert_allclose(ref_shape, tvm_out.shape)


def test_expand():

    def _test_expand(name, data, shape, ref_data):
        shape_array = np.array(shape)
        shape_node = onnx.helper.make_node('Constant',
                                           inputs=[],
                                           outputs=['shape'],
                                           value=onnx.helper.make_tensor(name='const_tensor',
                                                                         data_type=onnx.TensorProto.INT32,
                                                                         dims=shape_array.shape,
                                                                         vals=shape_array.flatten().astype('int32')))
        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

        graph = helper.make_graph([shape_node, expand_node],
                                  "expand_test",
                                  inputs=[helper.make_tensor_value_info("in",
                                                                        TensorProto.FLOAT, list(data.shape))],
                                  outputs=[helper.make_tensor_value_info("out",
                                                                         TensorProto.FLOAT, list(ref_data.shape))])

        model = helper.make_model(graph, producer_name=name)

        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 'float32')
            tvm.testing.assert_allclose(ref_data, tvm_out)

    in_shape = (3, 1)
    shape = (3, 4)
    data = np.random.uniform(size=in_shape).astype(np.float32)
    ref_data = np.tile(data, 4)
    _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data)

    in_shape = (3, 1)
    shape = (2, 1, 6)
    data = np.random.uniform(size=in_shape).astype(np.float32)
    ref_data = data * np.ones(shape, dtype=np.float32)
    _test_expand('expand_with_dim_changed_test', data, shape, ref_data)


def verify_depth_to_space(inshape, outshape, mode, blockSize):
    node = onnx.helper.make_node('DepthToSpace',
                                 inputs=['x'],
@@ -1710,6 +1750,7 @@ def test_or():
    test_flatten()
    test_reshape()
    test_shape()
    test_expand()
    test_power()
    test_squeeze()
    test_unsqueeze()
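Outside the test harness, one might run such a model through the Relay frontend end to end. A minimal sketch, assuming the tvm 0.6-era API; 'expand.onnx' is a hypothetical model file containing the graph built above.

import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

onnx_model = onnx.load('expand.onnx')  # hypothetical model file
data = np.random.uniform(size=(3, 1)).astype('float32')

# Import the ONNX graph into Relay; 'in' matches the input name used above.
mod, params = relay.frontend.from_onnx(onnx_model, shape={'in': (3, 1)})

# Compile for CPU and execute with the graph runtime.
with relay.build_config(opt_level=1):
    graph, lib, params = relay.build(mod, target='llvm', params=params)
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input('in', tvm.nd.array(data))
m.set_input(**params)
m.run()
tvm_out = m.get_output(0).asnumpy()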
