Skip to content

Commit

Permalink
Fix pytorch axis (#14930)
Browse files Browse the repository at this point in the history
* fix conv transpose import from TF

* fix String::fromwe() to String::from()

* torch squeeze can use a list of axis

* added test for squeeze with multiple axis (pytorch 2)

* clean old code

* code without reformating

---------

Co-authored-by: Mikael Sevenier <mikael.sevenier@sima.ai>
  • Loading branch information
mikeseven and Mikael Sevenier authored May 25, 2023
1 parent 1c39613 commit bcf7abb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def squeeze(self, inputs, input_types):
axis = None
else:
# TODO (t-vi): why is the cast to int needed? similarly elsewhere
axis = [int(inputs[1])]
inputs = [inputs[1]] if not isinstance(inputs[1], list) else inputs[1]
axis = [int(v) for v in inputs]

return _op.transform.squeeze(data, axis)

Expand Down
6 changes: 6 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,15 @@ class Squeeze2(Module):
def forward(self, *args):
return args[0].squeeze(1)

class Squeeze3(Module):
def forward(self, *args):
return args[0].squeeze((1, 3))

input_data = torch.rand(input_shape).float()
verify_model(Squeeze1().float().eval(), input_data=input_data)
verify_model(Squeeze2().float().eval(), input_data=input_data)
if package_version.parse(torch.__version__) >= package_version.parse("2.0.0"):
verify_model(Squeeze3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit bcf7abb

Please sign in to comment.