diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f0ad87f65757..0e04ffdc3a23 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1666,6 +1666,16 @@ def _impl(inputs, input_types): return _impl +def _gather(): + def _impl(inputs, input_types): + data = inputs[0] + axis = inputs[1] + indices = inputs[2] + + return _op.gather(data, axis, indices) + return _impl + + def _add(prelude): # add_ is overloaded for tensor add and list concat def _impl(inputs, input_types): @@ -2030,6 +2040,7 @@ def _get_convert_map(prelude): "aten::__getitem__" : _list_getitem(prelude), "aten::len" : _list_len(prelude), "aten::type_as" : _type_as(), + "aten::gather" : _gather(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6b731f47fc16..6a572db0bc29 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -954,6 +954,45 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Clone1().float().eval(), input_data=input_data) + +def test_forward_gather(): + torch.set_grad_enabled(False) + + class Gather1(Module): + def forward(self, *args): + return torch.gather(args[0], 0, args[1]) + + class Gather2(Module): + def forward(self, *args): + return torch.gather(args[0], 1, args[1]) + + class Gather3(Module): + def forward(self, *args): + return torch.gather(args[0], 2, args[1]) + + input_data = torch.rand((4,)).float() + index = torch.tensor([1]) + verify_model(Gather1().float().eval(), input_data=[input_data, index]) + + input_data = torch.rand((2, 2)).float() + index = torch.tensor([[1, 0], [0, 1]]) + verify_model(Gather1().float().eval(), input_data=[input_data, index]) + + input_data = torch.tensor([[1, 2], [3, 4]]) + index = torch.tensor([[0, 0], [1, 0]]) + verify_model(Gather2().float().eval(), input_data=[input_data, index]) + + input_data = torch.rand((2, 2)).float() + index = torch.tensor([[1, 0], [0, 1]]) + verify_model(Gather2().float().eval(), input_data=[input_data, index]) + + input_data = torch.rand((3, 3, 3)).float() + index = torch.tensor([[[1, 0, 0], [1, 0, 1], [0, 1, 1]], + [[1, 1, 1], [1, 2, 1], [1, 0, 1]], + [[1, 2, 1], [1, 2, 1], [1, 2, 1]]]) + verify_model(Gather3().float().eval(), input_data=[input_data, index]) + + def test_forward_logsoftmax(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2699,6 +2738,7 @@ def test_forward_pretrained_bert_base_uncased(): test_forward_mesh_grid() test_forward_chunk() test_forward_split() + test_forward_gather() test_upsample() test_forward_upsample3d() test_to()