Skip to content

Commit

Permalink
add onnx resize v10 and unit test (apache#6726)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Dec 4, 2020
1 parent 144b319 commit ff38e13
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
23 changes: 20 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,25 @@ def _impl_v7(cls, inputs, attr, params):
class Resize(OnnxOpConverter):
"""Operator converter for Resize"""

@classmethod
def _impl_v10(cls, inputs, attr, params):
mode = attr.get("mode")
if mode == b"nearest":
method = "nearest_neighbor"
elif mode == b"linear":
method = "bilinear"
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)
)

scale = inputs[1]
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale

layout = "NCHW" # ONNX assumes NCHW layout
out_size = _op.strided_slice(size, [2], [4])
return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")

@classmethod
def _impl_v11(cls, inputs, attr, params):
mode = attr.get("mode")
Expand All @@ -1891,9 +1910,7 @@ def _impl_v11(cls, inputs, attr, params):
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = (
_op.cast(_op.shape_of(inputs[0]), infer_type(scale).type_annotation.dtype) * scale
)
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale

coord_trans = attr.get("coordinate_transformation_mode")
if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
Expand Down
30 changes: 30 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3525,6 +3525,36 @@ def verify(ishape, oshape, scales, mode, coord_trans):
verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric")
verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")

def verify_opset_10(ishape, scales, mode):
nodes = [
make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales),
]
input_names = ["X", "scales"]
nodes.append(
helper.make_node(
"Resize",
inputs=input_names,
outputs=["Y"],
mode=mode,
)
)

oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)]
graph = helper.make_graph(
nodes,
"resize_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)],
)

model = helper.make_model(graph, producer_name="resize_test")
model.opset_import[0].version = 10

verify_with_ort(model, [ishape], oshape, use_vm=True, freeze_params=True)

verify_opset_10([1, 16, 32, 32], [1, 1, 2, 2], "nearest")
verify_opset_10([1, 16, 32, 32], [1, 1, 0.5, 0.5], "linear")


@tvm.testing.uses_gpu
def test_nonzero():
Expand Down

0 comments on commit ff38e13

Please sign in to comment.