Skip to content

Commit

Permalink
[TORCH] Make format checks happy with unified avg_pool
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Mar 18, 2021
1 parent 2dd497a commit 9e71fbd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,6 @@ def softplus(self, inputs, input_types):
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta

def make_avg_pool(self, dim):

def avg_pool(inputs, input_types):
data = inputs[0]

Expand Down
12 changes: 9 additions & 3 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,9 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data)
verify_model(AvgPool1D2().float().eval(), input_data=input_data)
verify_model(torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data)
verify_model(
torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data
)


@tvm.testing.uses_gpu
Expand All @@ -835,7 +837,9 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
verify_model(AvgPool2D2().float().eval(), input_data=input_data)
verify_model(torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data)
verify_model(
torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
)


@tvm.testing.uses_gpu
Expand All @@ -850,7 +854,9 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
verify_model(AvgPool3D1().float().eval(), input_data=input_data)
verify_model(torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data)
verify_model(
torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 9e71fbd

Please sign in to comment.