fuse constant padding into conv kernels #7515

Merged 4 commits on Mar 2, 2021
src/relay/transforms/simplify_expr.cc (116 additions, 0 deletions)
@@ -82,6 +82,121 @@ class SimplifyReshape : public SimplifyPattern {
  DFPattern x_;
};

/*!
 * \brief SimplifyConvPad matches a nn.pad followed by a nn.conv1d/nn.conv2d/nn.conv3d
 * and merges the explicit padding into the convolution's padding attributes.
 */
class SimplifyConvPad : public SimplifyPattern {
 public:
  SimplifyConvPad() {
    x_ = IsWildcard();
    w_ = IsWildcard();
    pad_ = IsOp("nn.pad")({x_});
    conv1d_ = IsOp("nn.conv1d");
    conv2d_ = IsOp("nn.conv2d");
    conv3d_ = IsOp("nn.conv3d");
    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
    pattern_ = conv_;
  }
  template <typename T>
  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
    ICHECK(old_attrs);
    ICHECK(padding.size() == old_attrs->padding.size())
        << "Number of dimensions to pad and convolution padding attributes should have the same "
           "extent";

    auto new_attrs = make_object<T>();
    Array<PrimExpr> combined_padding;
    for (size_t i = 0; i < padding.size(); ++i) {
      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
    }
    new_attrs->strides = old_attrs->strides;
    new_attrs->padding = combined_padding;
    new_attrs->dilation = old_attrs->dilation;
    new_attrs->groups = old_attrs->groups;
    new_attrs->channels = old_attrs->channels;
    new_attrs->kernel_size = old_attrs->kernel_size;
    new_attrs->data_layout = old_attrs->data_layout;
    new_attrs->kernel_layout = old_attrs->kernel_layout;
    new_attrs->out_layout = old_attrs->out_layout;
    new_attrs->out_dtype = old_attrs->out_dtype;
    return Attrs(new_attrs);
  }
  template <typename T>
  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
    ICHECK(param);
    ICHECK(attrs);
    ICHECK(attrs->data_layout.size() == param->pad_width.size())
        << "Data Layout and padding attributes should have the same extent";

    std::string data_layout = attrs->data_layout;
    std::set<char> image_dims({'H', 'W', 'D'});
    Array<PrimExpr> padding;
    // If we're padding a non-spatial dimension, don't simplify;
    // convolution can only pad on spatial axes.
    for (size_t i = 0; i < param->pad_width.size(); ++i) {
      if (!image_dims.count(data_layout[i])) {
        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
          if (param->pad_width[i][j] != 0) {
            return Attrs();
          }
        }
      }
    }
    // Collect the spatial padding in convolution order: all "begin" paddings first,
    // then all "end" paddings, e.g. (top, left, bottom, right) for NCHW.
    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
      for (size_t i = 0; i < param->pad_width.size(); ++i) {
        if (image_dims.count(data_layout[i])) {
          padding.push_back(param->pad_width[i][j]);
        }
      }
    }

    return MakeConvAttrs(attrs, padding);
  }
  Expr callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call_node = post.as<CallNode>();
    ICHECK(call_node);
    auto pad = node_map[pad_][0];
    const CallNode* pad_node = pad.as<CallNode>();
    ICHECK(pad_node);
    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
    ICHECK(param);
    // Only fold zero-valued constant padding; other pad modes or values change semantics.
    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
      Attrs attrs;
      if (node_map.count(conv1d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
      } else if (node_map.count(conv2d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
      } else if (node_map.count(conv3d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
      } else {
        return post;
      }
      if (!attrs.defined()) {
        return post;
      }
      auto x = node_map[x_][0];
      auto w = node_map[w_][0];
      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  /*! \brief Pattern input weight */
  DFPattern w_;
  /*! \brief Pattern pad */
  DFPattern pad_;
  /*! \brief Pattern conv */
  DFPattern conv_;
  DFPattern conv1d_;
  DFPattern conv2d_;
  DFPattern conv3d_;
};
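For illustration, a minimal Python sketch of the index shuffle that GetAttrs and MakeConvAttrs perform for an NCHW conv2d (the values and names below are illustrative, not part of the patch): nn.pad stores one [before, after] pair per input axis, while the convolution stores padding as all begin values followed by all end values over the spatial axes only.

# Illustrative only: mirrors GetAttrs/MakeConvAttrs for an NCHW conv2d.
pad_width = [[0, 0], [0, 0], [1, 2], [3, 4]]  # nn.pad: [before, after] per NCHW axis
orig_conv_padding = [1, 1, 1, 1]              # existing conv2d padding (top, left, bottom, right)

# Keep only spatial (D, H, W) axes; the real pass bails out if any other axis is padded.
spatial = [p for axis, p in zip("NCHW", pad_width) if axis in "DHW"]      # [[1, 2], [3, 4]]
# Transpose to convolution order: all "before" pads first, then all "after" pads.
folded = [spatial[i][j] for j in range(2) for i in range(len(spatial))]   # [1, 3, 2, 4]
# MakeConvAttrs then adds the folded padding to the conv's existing padding.
new_padding = [a + b for a, b in zip(folded, orig_conv_padding)]          # [2, 4, 3, 5]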

/*!
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
*/
@@ -163,6 +278,7 @@ class ExprSimplifier {
  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
    CreateCallback(SimplifyReshape());
    CreateCallback(FullElementwise());
    CreateCallback(SimplifyConvPad());
  }
  template <typename T>
  void CreateCallback(const T& pattern) {
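Taken together, the new callback lets SimplifyExpr fold an explicit zero pad into the following convolution. Below is a minimal sketch of the effect at the Relay level, assuming the TVM Python API of this release (shapes and variable names are illustrative); the test added in the next file exercises the same behavior across layouts and pad modes.

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
# Explicit zero padding on the spatial axes, followed by an unpadded conv2d.
y = relay.nn.pad(x, [[0, 0], [0, 0], [1, 1], [1, 1]], 0.0, "constant")
y = relay.nn.conv2d(y, w, padding=(0, 0))
simplified = run_opt_pass(y, transform.SimplifyExpr())
# After SimplifyExpr the nn.pad call is gone and the conv2d carries padding=(1, 1, 1, 1).
print(simplified)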
tests/python/relay/test_pass_simplify_expr.py (78 additions, 0 deletions)
@@ -19,6 +19,8 @@
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

import numpy as np


def test_simplify_reshape():
    def before():
@@ -122,6 +124,82 @@ def after_right(x, elem_op, value):
validate(shape, value, dtype)


def test_simplify_conv_pad():
    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]

    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
        if layout[1] == "C":
            shape = [1, 3] + [10] * ndim
            wshape = [8, 3] + [3] * ndim
        elif layout[-1] == "C":
            shape = [1] + [10] * ndim + [3]
            wshape = [8] + [3] * ndim + [3]
        else:
            raise ValueError("This test only supports NC* and N*C")

        x = relay.var("x", shape=shape, dtype="float32")
        w = relay.var("w", shape=wshape, dtype="float32")
        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
        if layout[1] == "C":
            conv = convs[ndim - 1](pad, w, padding=orig_padding)
        else:
            conv = convs[ndim - 1](
                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
            )

        if pad_mode == "constant" and pad_value == 0:
            new_padding = []
            for j in range(2):
                for i in range(len(pad_width)):
                    if layout[i] in ["D", "H", "W"]:
                        new_padding.append(pad_width[i][j])
            for i in range(len(new_padding)):
                new_padding[i] += orig_padding[i]
            if layout[1] == "C":
                after = convs[ndim - 1](x, w, padding=new_padding)
            else:
                after = convs[ndim - 1](
                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
                )
        else:
            after = conv

        zz = run_opt_pass(conv, transform.SimplifyExpr())
        expected = run_opt_pass(after, transform.InferType())
        assert tvm.ir.structural_equal(zz, expected)

        mod1 = tvm.IRModule.from_expr(conv)
        mod2 = tvm.IRModule.from_expr(zz)

        with tvm.transform.PassContext(disabled_pass="SimplifyExpr"):
            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
            x_np = np.random.rand(*shape).astype("float32")
            w_np = np.random.rand(*wshape).astype("float32")
            result1 = ex1.evaluate()(x_np, w_np)
            result2 = ex2.evaluate()(x_np, w_np)

            tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())

    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
        for i_pad in [[0, 0], [1, 1], [1, 0]]:
            for ndim in [1, 2, 3]:
                for channels_last in [0, 1]:
                    if channels_last:
                        layout = "NDHWC"
                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
                    else:
                        layout = "NCDHW"
                        layout = layout[0:2] + layout[5 - ndim :]
                        padding = [[0, 0]] * 2 + [i_pad] * ndim

                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
    ndim = 2
    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")


if __name__ == "__main__":
    test_simplify_reshape()
    test_simplify_full_elementwise()