From 5aa2ed5d128a7a97cac8958ccb7833761cac3bfb Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 18 Aug 2020 10:49:09 +0900 Subject: [PATCH 1/2] support index select --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a1cabcd5ae22..ebd60469d137 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2056,6 +2056,7 @@ def _get_convert_map(prelude): "aten::len" : _list_len(prelude), "aten::type_as" : _type_as(), "aten::gather" : _gather(), + "aten::index_select" : _select(), } return convert_map @@ -2666,5 +2667,4 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d default_dtype=default_dtype) mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) - return mod, tvm_params diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 88203f560641..2302f0fdb74a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -973,6 +973,7 @@ def forward(self, *args): verify_model(View2().float().eval(), input_data=input_data) verify_model(View3().float().eval(), input_data=input_data) + def test_forward_select(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -981,9 +982,26 @@ class Select1(Module): def forward(self, *args): return args[0].select(1, 1) + class IndexedSelect(Module): + def __init__(self, inp, dim): + super().__init__() + self.inp = inp + self.dim = dim + if torch.cuda.is_available(): + self.inp = self.inp.cuda() + + def forward(self, index): + return torch.index_select(self.inp, self.dim, index) + input_data = torch.rand(input_shape).float() verify_model(Select1().float().eval(), input_data=input_data) + x = torch.randn(3, 4) + indices = torch.tensor([0, 2]) + verify_model(IndexedSelect(x, 0).eval(), input_data=indices) + verify_model(IndexedSelect(x, 1).eval(), input_data=indices) + + def test_forward_clone(): torch.set_grad_enabled(False) input_shape = [10] From 6be43b875f637e63d977b2500bf76fe7120f3823 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 18 Aug 2020 10:54:23 +0900 Subject: [PATCH 2/2] minor fix --- python/tvm/relay/frontend/pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ebd60469d137..235cec0f096d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2667,4 +2667,5 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d default_dtype=default_dtype) mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) + return mod, tvm_params