Skip to content

Commit

Permalink
[TORCH] Implement avg_pool1d
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Mar 18, 2021
1 parent 38aed59 commit 33fe25d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,31 @@ def softplus(self, inputs, input_types):
beta = _expr.const(float(inputs[1]), dtype=dtype)
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta

def avg_pool1d(self, inputs, input_types):
data = inputs[0]

pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])

def func(x):
return _op.nn.avg_pool1d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)

if self.is_quantized_tensor(data):
return qnn_torch.apply_with_upcast(data, func)

return func(data)


def avg_pool2d(self, inputs, input_types):
data = inputs[0]

Expand Down Expand Up @@ -2338,6 +2363,7 @@ def create_convert_map(self):
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
"aten::softplus": self.softplus,
"aten::avg_pool1d": self.avg_pool1d,
"aten::avg_pool2d": self.avg_pool2d,
"aten::avg_pool3d": self.avg_pool3d,
"aten::linear": self.linear,
Expand Down
19 changes: 17 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,21 @@ def forward(self, *args):


@tvm.testing.uses_gpu
def test_forward_avgpool():
def test_forward_avgpool1d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10]

class AvgPool1D2(Module):
def forward(self, *args):
return torch.nn.functional.avg_pool1d(args[0], kernel_size=[10])

input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data)
verify_model(AvgPool1D2().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_avgpool2d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

Expand Down Expand Up @@ -3831,7 +3845,8 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
test_forward_avgpool()
test_forward_avgpool1d()
test_forward_avgpool2d()
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()
Expand Down

0 comments on commit 33fe25d

Please sign in to comment.