Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Support index_select #6295

Merged
merged 2 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,6 +2056,7 @@ def _get_convert_map(prelude):
"aten::len" : _list_len(prelude),
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
"aten::index_select" : _select(),
}
return convert_map

Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ def forward(self, *args):
verify_model(View2().float().eval(), input_data=input_data)
verify_model(View3().float().eval(), input_data=input_data)


def test_forward_select():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand All @@ -981,9 +982,26 @@ class Select1(Module):
def forward(self, *args):
return args[0].select(1, 1)

class IndexedSelect(Module):
def __init__(self, inp, dim):
super().__init__()
self.inp = inp
self.dim = dim
if torch.cuda.is_available():
self.inp = self.inp.cuda()

def forward(self, index):
return torch.index_select(self.inp, self.dim, index)

input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
verify_model(IndexedSelect(x, 1).eval(), input_data=indices)


def test_forward_clone():
torch.set_grad_enabled(False)
input_shape = [10]
Expand Down