Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][simplify_expr]: Add pass to remove trivial transpose ops #14858

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 75 additions & 36 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,38 @@ class SimplifyCastClip : public DFPatternRewrite {
DFPattern clip_, cast_;
};

/*!
 * \brief Return the axis order for layout transform and transpose
 * ops.
 *
 * For a transpose call the explicit axes are read from its attributes, with
 * negative axes normalized by adding ndim; when no axes are defined the order
 * is the reversed range [ndim - 1, ..., 0]. For a layout_transform call,
 * entry i is the index within the source layout of the i-th axis of the
 * destination layout.
 *
 * \param call The call whose attrs are TransposeAttrs or LayoutTransformAttrs.
 * \param ndim The number of axes to produce; also used to normalize negative axes.
 * \return The axis order, containing ndim entries.
 */
static std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) {
  std::vector<int> attr_axes;
  if (ndim > 0) {
    // The final size is known up front, so allocate once.
    attr_axes.reserve(ndim);
  }
  if (auto attr = call->attrs.as<TransposeAttrs>()) {
    if (attr->axes.defined()) {
      for (int i = 0; i < ndim; ++i) {
        int64_t axis = attr->axes[i].IntValue();
        // Map negative axes onto their non-negative equivalent.
        if (axis < 0) {
          axis += ndim;
        }
        attr_axes.push_back(static_cast<int>(axis));
      }
    } else {
      // Empty axes means reverse
      for (int i = ndim - 1; i >= 0; --i) {
        attr_axes.push_back(i);
      }
    }
  } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
    Layout src_layout(attr->src_layout);
    Layout dst_layout(attr->dst_layout);
    // NOTE(review): IndexOf presumably reports an axis missing from the source
    // layout with a negative index; callers should confirm both layouts use
    // the same set of axes before relying on the result.
    for (int i = 0; i < ndim; ++i) {
      attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
    }
  } else {
    CHECK(false) << "Expected transpose or layout_transform, but got "
                 << Downcast<Op>(call->op)->name;
  }
  // Return the local by name: wrapping it in std::move would disable NRVO.
  return attr_axes;
}

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -316,19 +348,7 @@ class SimplifyTranspose : public DFPatternRewrite {
it++;
}

// Check if the transpose is still required
bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
}

if (need_transpose) {
return MakeTranspose(x, axes);
}
return x;
return MakeTranspose(x, axes);
}

String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
Expand Down Expand Up @@ -431,32 +451,50 @@ class SimplifyTranspose : public DFPatternRewrite {
return Downcast<Call>(output_layout_trans);
}

std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
std::vector<int> attr_axes;
if (auto attr = call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i].IntValue();
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
private:
/*! \brief Pattern input */
DFPattern x_;
};

/*!
* \brief SimplifyNoOpTranspose matches the pattern of transpose or
* layout transform ops which do not change the layout or rank and
* removes the op.
*/
class SimplifyNoOpTranspose : public DFPatternRewrite {
public:
SimplifyNoOpTranspose() {
x_ = IsWildcard();
auto trans1 = IsOp("transpose") || IsOp("layout_transform");
pattern_ = trans1({x_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
Call trans_call = Downcast<Call>(post);

// Do not remove ops which change rank
if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
if (attr->src_layout != attr->dst_layout) {
return post;
}
} else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
}

int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
auto axes = GetTransposeAxisOrder(trans_call, ndim);

bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(call->op)->name;
}
return std::move(attr_axes);

if (!need_transpose) return x;

return post;
}

private:
Expand Down Expand Up @@ -1037,6 +1075,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
composer.AddRewrite<SimplifyNoOpTranspose>();
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
Expand Down
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,27 @@ def expected10():
y = relay.nn.relu(y)
return relay.Function([x], y)

def before11():
    """
    Build a graph whose layout ops are all no-ops.

    Input:
        x -> relay.transpose(axes=[0, 1, 2, 3]) -> relu
          -> relay.layout_transform("NCHW", "NCHW")

    Simplified:
        x -> relu
    """
    data = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    # Identity permutation: every axis stays in place.
    identity_perm = relay.transpose(data, axes=[0, 1, 2, 3])
    activated = relay.nn.relu(identity_perm)
    # Source and destination layouts are the same string.
    same_layout = relay.layout_transform(activated, "NCHW", "NCHW")
    return relay.Function([data], same_layout)

def expected11():
    """Reference graph for before11: relu applied directly to the input."""
    data = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    return relay.Function([data], relay.nn.relu(data))

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
Expand All @@ -277,6 +298,7 @@ def expected10():
[before8(), expected8()],
[before9(), expected9()],
[before10(), expected10()],
[before11(), expected11()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
Expand Down