Skip to content

Commit

Permalink
[PYTORCH]Gather op support added (#6013)
Browse files Browse the repository at this point in the history
* [PYTORCH]Gather op support added

* retrigger
  • Loading branch information
siju-samuel authored Jul 9, 2020
1 parent bfe73b2 commit 8887b0f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,16 @@ def _impl(inputs, input_types):
return _impl


def _gather():
def _impl(inputs, input_types):
data = inputs[0]
axis = inputs[1]
indices = inputs[2]

return _op.gather(data, axis, indices)
return _impl


def _add(prelude):
# add_ is overloaded for tensor add and list concat
def _impl(inputs, input_types):
Expand Down Expand Up @@ -2030,6 +2040,7 @@ def _get_convert_map(prelude):
"aten::__getitem__" : _list_getitem(prelude),
"aten::len" : _list_len(prelude),
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
}
return convert_map

Expand Down
40 changes: 40 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,45 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(Clone1().float().eval(), input_data=input_data)


def test_forward_gather():
torch.set_grad_enabled(False)

class Gather1(Module):
def forward(self, *args):
return torch.gather(args[0], 0, args[1])

class Gather2(Module):
def forward(self, *args):
return torch.gather(args[0], 1, args[1])

class Gather3(Module):
def forward(self, *args):
return torch.gather(args[0], 2, args[1])

input_data = torch.rand((4,)).float()
index = torch.tensor([1])
verify_model(Gather1().float().eval(), input_data=[input_data, index])

input_data = torch.rand((2, 2)).float()
index = torch.tensor([[1, 0], [0, 1]])
verify_model(Gather1().float().eval(), input_data=[input_data, index])

input_data = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
verify_model(Gather2().float().eval(), input_data=[input_data, index])

input_data = torch.rand((2, 2)).float()
index = torch.tensor([[1, 0], [0, 1]])
verify_model(Gather2().float().eval(), input_data=[input_data, index])

input_data = torch.rand((3, 3, 3)).float()
index = torch.tensor([[[1, 0, 0], [1, 0, 1], [0, 1, 1]],
[[1, 1, 1], [1, 2, 1], [1, 0, 1]],
[[1, 2, 1], [1, 2, 1], [1, 2, 1]]])
verify_model(Gather3().float().eval(), input_data=[input_data, index])


def test_forward_logsoftmax():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -2699,6 +2738,7 @@ def test_forward_pretrained_bert_base_uncased():
test_forward_mesh_grid()
test_forward_chunk()
test_forward_split()
test_forward_gather()
test_upsample()
test_forward_upsample3d()
test_to()
Expand Down

0 comments on commit 8887b0f

Please sign in to comment.