Skip to content

Commit

Permalink
Support aten::flip (#8398)
Browse files Browse the repository at this point in the history
* Support test aten::flip

* Support aten::flip
  • Loading branch information
delldu authored Jul 4, 2021
1 parent a00d211 commit d17f753
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,11 @@ def nll_loss(self, inputs, input_types):
weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)

def flip(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -2536,6 +2541,7 @@ def create_convert_map(self):
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::flip": self.flip,
}

def update_convert_map(self, custom_map):
Expand Down
20 changes: 20 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3893,6 +3893,25 @@ def test_forward_nll_loss():
verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])


@tvm.testing.uses_gpu
def test_forward_flip():
torch.set_grad_enabled(False)

class Flip(Module):
def __init__(self, axis=0):
super().__init__()
self.axis = axis

def forward(self, x):
return x.flip([self.axis])

input = torch.randn(2, 3, 4)
verify_model(Flip(axis=0), input_data=input)
verify_model(Flip(axis=1), input_data=input)
verify_model(Flip(axis=2), input_data=input)
verify_model(Flip(axis=-1), input_data=input)


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -4035,6 +4054,7 @@ def test_forward_nll_loss():
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()
test_forward_flip()

# Model tests
test_resnet18()
Expand Down

0 comments on commit d17f753

Please sign in to comment.