[Relay][Op] Adaptive pooling #3085

Merged · 13 commits · May 9, 2019
1 change: 1 addition & 0 deletions docs/api/python/topi.rst
@@ -57,6 +57,7 @@ List of operators
topi.nn.dilate
topi.nn.pool
topi.nn.global_pool
topi.nn.adaptive_pool
topi.nn.upsampling
topi.nn.softmax
topi.nn.dense
4 changes: 4 additions & 0 deletions docs/langref/relay_op.rst
@@ -189,6 +189,8 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.annotation.on_device
tvm.relay.reverse_reshape
tvm.relay.nn.batch_matmul
tvm.relay.contrib.adaptive_max_pool2d
tvm.relay.contrib.adaptive_avg_pool2d


Level 1 Definitions
@@ -318,3 +320,5 @@ Level 10 Definitions
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -332,6 +332,22 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
}
};

/*! \brief Attributes for adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
  Array<IndexExpr> output_size;
  std::string layout;

  TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
    TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({}))
        .describe("Output height and width.");
    TVM_ATTR_FIELD(layout).set_default("NCHW")
        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                  "'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
                  "dimensions respectively. Pooling is applied on the 'H' and "
                  "'W' dimensions.");
  }
};


/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
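A note on the semantics these attributes imply: adaptive pooling carries no kernel or stride attribute; each output index derives its window from the input/output extents, with begin = floor(i * in / out) and end = ceil((i + 1) * in / out). A minimal NumPy sketch of that formula (an illustration of the technique only, not TVM's actual compute definition):

```python
import numpy as np

def adaptive_avg_pool2d_ref(x, output_size):
    """Reference adaptive average pooling over NCHW input.

    Window i spans [floor(i * in / out), ceil((i + 1) * in / out)),
    so the windows tile the input exactly even when in % out != 0.
    """
    n, c, h, w = x.shape
    out_h, out_w = output_size
    y = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        h0, h1 = i * h // out_h, -((i + 1) * h // -out_h)  # floor, ceil
        for j in range(out_w):
            w0, w1 = j * w // out_w, -((j + 1) * w // -out_w)
            y[:, :, i, j] = x[:, :, h0:h1, w0:w1].mean(axis=(2, 3))
    return y

# e.g. pooling a 224x224 feature map down to 7x7 yields exact 32x32 windows
out = adaptive_avg_pool2d_ref(np.random.rand(1, 3, 224, 224).astype("float32"), (7, 7))
assert out.shape == (1, 3, 7, 7)
```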
4 changes: 2 additions & 2 deletions nnvm/python/nnvm/top/nn.py
@@ -399,7 +399,7 @@ def schedule_avg_pool2d(attrs, outs, target):
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
    with tvm.target.create(target):
-        return topi.generic.schedule_global_pool(outs)
+        return topi.generic.schedule_adaptive_pool(outs)

reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -409,7 +409,7 @@ def schedule_global_max_pool2d(_, outs, target):
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
    with tvm.target.create(target):
-        return topi.generic.schedule_global_pool(outs)
+        return topi.generic.schedule_adaptive_pool(outs)

reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
@@ -40,6 +40,7 @@
from . import nn
from . import annotation
from . import vision
from . import contrib
from . import image
from . import frontend
from . import backend
20 changes: 20 additions & 0 deletions python/tvm/relay/contrib.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Contrib operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.contrib import *
5 changes: 1 addition & 4 deletions python/tvm/relay/frontend/mxnet.py
@@ -190,10 +190,7 @@ def _pool2d(new_op, is_avg):

def _mx_adaptive_avg_pooling(inputs, attrs):
    output_size = attrs.get_int_tuple("output_size", [])
-    if output_size != (1,):
-        raise tvm.error.OpAttributeUnimplemented(
-            "AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
-    return _op.nn.global_avg_pool2d(inputs[0])
+    return _op.contrib.adaptive_avg_pool2d(inputs[0], output_size)


def _mx_dropout(inputs, attrs):
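With the restriction removed, MXNet graphs that use adaptive average pooling with an arbitrary output size should convert directly instead of raising OpAttributeUnimplemented. A hedged sketch of the conversion path (the MXNet contrib symbol name and the exact from_mxnet return shape vary by version):

```python
import mxnet as mx
from tvm import relay

data = mx.sym.Variable("data")
# A non-global adaptive pool; previously only output_size=(1,) converted.
net = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(7, 7))

# The frontend now emits relay.contrib.adaptive_avg_pool2d for this symbol.
func, params = relay.frontend.from_mxnet(net, shape={"data": (1, 3, 224, 224)})
print(func)
```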
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
from . import annotation
from . import image
from . import vision
from . import contrib
from . import op_attrs


21 changes: 21 additions & 0 deletions python/tvm/relay/op/contrib/__init__.py
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .contrib import *
from . import _contrib
43 changes: 43 additions & 0 deletions python/tvm/relay/op/contrib/_contrib.py
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

import topi
from .. import op as reg
from ..op import OpPattern


# adaptive_max_pool2d
@reg.register_schedule("contrib.adaptive_max_pool2d")
def schedule_adaptive_max_pool2d(_, outs, target):
    """Schedule definition of adaptive_max_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)

reg.register_pattern("contrib.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# adaptive_avg_pool2d
@reg.register_schedule("contrib.adaptive_avg_pool2d")
def schedule_adaptive_avg_pool2d(_, outs, target):
    """Schedule definition of adaptive_avg_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)

reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
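Registering the ops as OUT_ELEMWISE_FUSABLE lets the fusion pass attach trailing elementwise ops to the pooling kernel. A sketch of compiling such a graph (relay.Module.from_expr and this build signature are assumptions based on the TVM API of this period):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32))
# The relu can fuse into the pooling output thanks to the
# OUT_ELEMWISE_FUSABLE pattern registered above.
y = relay.nn.relu(relay.contrib.adaptive_max_pool2d(x, (4, 4)))
func = relay.Function([x], y)

graph, lib, params = relay.build(relay.Module.from_expr(func), target="llvm")
```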
20 changes: 20 additions & 0 deletions python/tvm/relay/op/contrib/_make.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api

_init_api("relay.op.contrib._make", __name__)
113 changes: 113 additions & 0 deletions python/tvm/relay/op/contrib/contrib.py
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, too-many-lines
"""Contrib operations."""
from __future__ import absolute_import as _abs
from . import _make


def adaptive_max_pool2d(data,
                        output_size=None,
                        layout="NCHW"):
    r"""2D adaptive max pooling operator. This operator is experimental.

    This operator takes data as input and computes the 2D max value over
    each window; the windows are derived from the desired output size.

    In the default case, where the data_layout is `NCHW`, a data Tensor
    with shape `(batch_size, in_channels, height, width)` is used to
    produce an output Tensor with shape
    `(batch_size, in_channels, output_height, output_width)`.

    The pooling kernel and stride sizes are automatically chosen to
    match the desired output size.

    For output_size:
        If this argument is not provided, input height and width will be used
        as output height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size) for any input (NCHW).

        If a tuple of integers (height, width) is provided for output_size,
        the output size is (N x C x height x width) for any input (NCHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : tuple of int, optional
        Output height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    output_size = output_size or []
    return _make.adaptive_max_pool2d(data, output_size, layout)

def adaptive_avg_pool2d(data,
                        output_size=None,
                        layout="NCHW"):
    r"""2D adaptive average pooling operator. This operator is experimental.

    This operator takes data as input and computes the 2D average value over
    each window; the windows are derived from the desired output size.

    In the default case, where the data_layout is `NCHW`, a data Tensor
    with shape `(batch_size, in_channels, height, width)` is used to
    produce an output Tensor with shape
    `(batch_size, in_channels, output_height, output_width)`.

    The pooling kernel and stride sizes are automatically chosen to
    match the desired output size.

    For output_size:
        If this argument is not provided, input height and width will be used
        as output height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size) for any input (NCHW).

        If a tuple of integers (height, width) is provided for output_size,
        the output size is (N x C x height x width) for any input (NCHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : tuple of int, optional
        Output height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    output_size = output_size or []
    return _make.adaptive_avg_pool2d(data, output_size, layout)
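Putting the two new wrappers together, a small usage sketch (shapes are illustrative; the output shapes follow the output_size rules documented above):

```python
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))

# None keeps the input height/width; an int or a (height, width) tuple
# selects the output resolution.
pooled_max = relay.contrib.adaptive_max_pool2d(x, output_size=(7, 7))  # (1, 64, 7, 7)
pooled_avg = relay.contrib.adaptive_avg_pool2d(x, output_size=1)       # (1, 64, 1, 1)

func = relay.Function([x], relay.Tuple([pooled_max, pooled_avg]))
print(func)
```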
5 changes: 3 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
@@ -247,7 +247,7 @@ def schedule_avg_pool2d(attrs, outs, target):
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
    with target:
-        return topi.generic.schedule_global_pool(outs)
+        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -258,11 +258,12 @@ def schedule_global_max_pool2d(_, outs, target):
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
    with target:
-        return topi.generic.schedule_global_pool(outs)
+        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# leaky_relu
reg.register_schedule("nn.leaky_relu", schedule_broadcast)
reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)