From 333119c9229b4f63fd6d11e43cca63a9a49cf1c8 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 3 Jul 2021 22:57:37 +0800 Subject: [PATCH 1/2] Support test aten::flip --- tests/python/frontend/pytorch/test_forward.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2ec281094080..f76ea9a5d324 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3893,6 +3893,25 @@ def test_forward_nll_loss(): verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) +@tvm.testing.uses_gpu +def test_forward_flip(): + torch.set_grad_enabled(False) + + class Flip(Module): + def __init__(self, axis=0): + super().__init__() + self.axis = axis + + def forward(self, x): + return x.flip([self.axis]) + + input = torch.randn(2, 3, 4) + verify_model(Flip(axis=0), input_data=input) + verify_model(Flip(axis=1), input_data=input) + verify_model(Flip(axis=2), input_data=input) + verify_model(Flip(axis=-1), input_data=input) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4035,6 +4054,7 @@ def test_forward_nll_loss(): test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() + test_forward_flip() # Model tests test_resnet18() From 0ffb31be090f05e79295795db4262b4c26a3db43 Mon Sep 17 00:00:00 2001 From: Dell Du <18588220928@163.com> Date: Sat, 3 Jul 2021 22:58:27 +0800 Subject: [PATCH 2/2] Support aten::flip --- python/tvm/relay/frontend/pytorch.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..aa0217db046c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2325,6 +2325,11 @@ def nll_loss(self, inputs, input_types): weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + def flip(self, inputs, input_types): + data = inputs[0] + axis = inputs[1] + return _op.transform.reverse(data, axis=axis[0]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2539,6 +2544,7 @@ def create_convert_map(self): "aten::_unique2": self.unique, "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, + "aten::flip": self.flip, } def update_convert_map(self, custom_map):