Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] Add support for Size op in Onnx frontend. (#7031
Browse files Browse the repository at this point in the history
)

* Add support for Size op in Onnx frontend.

* Simplify target testing.
  • Loading branch information
jwfromm authored Dec 4, 2020
1 parent f278c42 commit c8397bf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2374,6 +2374,7 @@ def _get_convert_map(opset):
"Gather": Gather.get_converter(opset),
"GatherElements": GatherElements.get_converter(opset),
"GatherND": GatherND.get_converter(opset),
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
Expand Down
28 changes: 28 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3888,6 +3888,33 @@ def test_if():
tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)


@tvm.testing.uses_gpu
def test_size():
def verify_size(indata):
node = helper.make_node(
"Size",
inputs=["X"],
outputs=["Y"],
)

graph = helper.make_graph(
[node],
"size_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, [])],
)

model = helper.make_model(graph, producer_name="size_test")

verify_with_ort_with_inputs(model, [indata], dtype="int64", use_vm=True, opset=11)

input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
verify_size(input_data)

input_data = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]], dtype=np.int64)
verify_size(input_data)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -3964,3 +3991,4 @@ def test_if():
test_roi_align()
test_range()
test_loop()
test_size()

0 comments on commit c8397bf

Please sign in to comment.