Skip to content

Commit

Permalink
Support for pixel_unshuffle (apache#34)
Browse files Browse the repository at this point in the history
Co-authored-by: Siva <quic_sivb@quicinc.com>
  • Loading branch information
srkreddy1238 committed Sep 30, 2024
1 parent 4088e23 commit 7a51055
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,31 @@ def reshape_as(self, inputs, input_types):
new_shape = self.infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)

def pixel_unshuffle(self, inputs, input_types):
data = inputs[0]
downscale_factor = inputs[1]
downscale_squared = downscale_factor * downscale_factor
b, c, h, w = self.infer_shape(data)
assert h % downscale_factor == 0, "input height should be divisible by downscale_factor"
assert w % downscale_factor == 0, "input width should be divisible by downscale_factor"

ndims = len(self.infer_shape_with_prelude(data))
axes = list(range(ndims))
oc = c * downscale_squared
oh = h // downscale_factor
ow = w // downscale_factor

new_shape = [b, c, oh, downscale_factor, ow, downscale_factor]
out_shape = [b, oc, oh, ow]

data = _op.transform.reshape(data, new_shape)
# The data will be transposed to
# [b, c, downscale_factor, downscale_factor, oh, ow]
# for further reshape
axes = [0, 1, 3, 5, 2, 4]
data = _op.transform.transpose(data, axes)
return _op.transform.reshape(data, out_shape)

def pixel_shuffle(self, inputs, input_types):
data = inputs[0]
upscale_factor = inputs[1]
Expand Down Expand Up @@ -4037,6 +4062,7 @@ def create_convert_map(self):
self.convert_map = {
"aten::is_floating_point": self.is_floating_point,
"aten::pixel_shuffle": self.pixel_shuffle,
"aten::pixel_unshuffle": self.pixel_unshuffle,
"aten::device": self.none,
"prim::device": self.none,
"aten::sub": self.sub,
Expand Down
17 changes: 17 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,23 @@ def test_forward_pixel_shuffle():
verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_pixel_unshuffle():
"""test_forward_pixel_unshuffle"""
torch.set_grad_enabled(False)
input_shape = [1, 36, 32, 32]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.PixelUnshuffle(2).float().eval(), input_data=input_data)

input_shape = [1, 16, 48, 48]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.PixelUnshuffle(3).float().eval(), input_data=input_data)

input_shape = [1, 9, 64, 64]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.PixelUnshuffle(4).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_add():
"""test_forward_add"""
Expand Down

0 comments on commit 7a51055

Please sign in to comment.