Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Simplify consecutive transpose/layout_transform #7656

Merged
merged 5 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeStack(Expr data, int axis);

Expr MakeTranspose(Expr data, Array<Integer> axes);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
String slice_mode);

Expand Down
92 changes: 92 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,97 @@ class SimplifyReshape : public SimplifyPattern {
DFPattern x_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
*/
class SimplifyTranspose : public SimplifyPattern {
public:
SimplifyTranspose() {
x_ = IsWildcard();
auto trans1 = IsOp("transpose") || IsOp("layout_transform");
auto trans2 = IsOp("transpose") || IsOp("layout_transform");
pattern_ = trans1({trans2({x_})});
}

Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
// Helper function to get the axes from call node attribute
auto get_axes_from_call = [](const Call trans_call, int ndim) {
std::vector<int> attr_axes;
if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(attr->axes[i]);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
}
} else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(trans_call->op)->name;
}
return std::move(attr_axes);
};

auto x = node_map[x_][0];

// Initialize axes
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
Array<Integer> axes;
for (int i = 0; i < ndim; ++i) {
axes.push_back(i);
}

// Collect axes changes from the matched pattern, including two consecutive transposes.
std::vector<std::vector<int>> interm_axes;
Call trans_call = Downcast<Call>(post);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
trans_call = Downcast<Call>(trans_call->args[0]);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));

// Calculate the final axes in reverse order (from root to output)
auto it = interm_axes.rbegin();
while (it != interm_axes.rend()) {
auto interm = *it;

Array<Integer> new_axes;
for (int i = 0; i < ndim; ++i) {
new_axes.push_back(axes[interm[i]]);
}
axes = new_axes;
it++;
}

// Check if the transpose is still required
bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
}

if (need_transpose) {
return MakeTranspose(x, axes);
}
return x;
}

private:
/*! \brief Pattern input */
DFPattern x_;
};

/*!
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
*/
Expand Down Expand Up @@ -162,6 +253,7 @@ class ExprSimplifier {
public:
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
CreateCallback(SimplifyReshape());
CreateCallback(SimplifyTranspose());
CreateCallback(FullElementwise());
}
template <typename T>
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,54 @@ def symbolic():
assert tvm.ir.structural_equal(zz, after)


def test_simplify_transpose():
def before1():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
y = relay.layout_transform(y, "NHWC", "HWCN") # To HWCN
y = relay.transpose(y, axes=[3, 0, 1, 2]) # To NHWC
return relay.Function([x], y)

def expected1():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
return relay.Function([x], y)

def before2():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
y = relay.transpose(y, axes=[0, 2, 3, 1]) # To NHWC
y = relay.transpose(y, axes=[1, 2, 3, 0]) # To HWCN
y = relay.transpose(y, axes=[3, 2, 0, 1]) # To NCHW
return relay.Function([x], y)

def expected2():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
return relay.Function([x], y)

def before3():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
y = relay.transpose(y) # Reverse
y = relay.transpose(y) # Reverse
return relay.Function([x], y)

def expected3():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
return relay.Function([x], y)

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
assert tvm.ir.structural_equal(after, expected)
comaniac marked this conversation as resolved.
Show resolved Hide resolved


def test_simplify_full_elementwise():
def validate(shape, value, dtype):
def before_left(x, elem_op, full):
Expand Down Expand Up @@ -126,4 +174,5 @@ def after_right(x, elem_op, value):

if __name__ == "__main__":
test_simplify_reshape()
test_simplify_transpose()
test_simplify_full_elementwise()