From ab972bd1140989b97bed899034703f5f0ea534d4 Mon Sep 17 00:00:00 2001 From: Siva Date: Fri, 27 Sep 2024 06:20:38 +0530 Subject: [PATCH] Support for pixel_unshuffle (#34) Co-authored-by: Siva --- python/tvm/relay/frontend/pytorch.py | 26 +++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 17 ++++++++++++ 2 files changed, 43 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0d93ff987c6e..b8bd9fb6620b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1696,6 +1696,31 @@ def reshape_as(self, inputs, input_types): new_shape = self.infer_shape(inputs[1]) return _op.transform.reshape(data, new_shape) + def pixel_unshuffle(self, inputs, input_types): + data = inputs[0] + downscale_factor = inputs[1] + downscale_squared = downscale_factor * downscale_factor + b, c, h, w = self.infer_shape(data) + assert h % downscale_factor == 0, "input height should be divisible by downscale_factor" + assert w % downscale_factor == 0, "input width should be divisible by downscale_factor" + + ndims = len(self.infer_shape_with_prelude(data)) + axes = list(range(ndims)) + oc = c * downscale_squared + oh = h // downscale_factor + ow = w // downscale_factor + + new_shape = [b, c, oh, downscale_factor, ow, downscale_factor] + out_shape = [b, oc, oh, ow] + + data = _op.transform.reshape(data, new_shape) + # The data will be transposed to + # [b, c, downscale_factor, downscale_factor, oh, ow] + # for further reshape + axes = [0, 1, 3, 5, 2, 4] + data = _op.transform.transpose(data, axes) + return _op.transform.reshape(data, out_shape) + def pixel_shuffle(self, inputs, input_types): data = inputs[0] upscale_factor = inputs[1] @@ -4037,6 +4062,7 @@ def create_convert_map(self): self.convert_map = { "aten::is_floating_point": self.is_floating_point, "aten::pixel_shuffle": self.pixel_shuffle, + "aten::pixel_unshuffle": self.pixel_unshuffle, "aten::device": self.none, "prim::device": self.none, "aten::sub": self.sub, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9f8fac93061c..ee646191e05b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -300,6 +300,23 @@ def test_forward_pixel_shuffle(): verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_pixel_unshuffle(): + """test_forward_pixel_unshuffle""" + torch.set_grad_enabled(False) + input_shape = [1, 36, 32, 32] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.PixelUnshuffle(2).float().eval(), input_data=input_data) + + input_shape = [1, 16, 48, 48] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.PixelUnshuffle(3).float().eval(), input_data=input_data) + + input_shape = [1, 9, 64, 64] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.PixelUnshuffle(4).float().eval(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_add(): """test_forward_add"""