From 5e9bfc0b62e365488558504b66399c0f31db998a Mon Sep 17 00:00:00 2001 From: Krishna Bindumadhavan Date: Mon, 15 May 2023 22:17:59 +0000 Subject: [PATCH] [relay][simplify_expr]: Add pattern to remove trivial transpose ops --- src/relay/transforms/simplify_expr.cc | 111 ++++++++++++------ tests/python/relay/test_pass_simplify_expr.py | 22 ++++ 2 files changed, 97 insertions(+), 36 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 4f255647df87..a557f2496b58 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -260,6 +260,38 @@ class SimplifyCastClip : public DFPatternRewrite { DFPattern clip_, cast_; }; +/*! + * \brief Return the axis order for layout transform and transpose + * ops. + */ +static std::vector GetTransposeAxisOrder(const Call& call, int ndim) { + std::vector attr_axes; + if (auto attr = call->attrs.as()) { + if (attr->axes.defined()) { + for (int i = 0; i < ndim; ++i) { + int64_t axis = attr->axes[i].IntValue(); + axis += (axis < 0) ? ndim : 0; + attr_axes.push_back(axis); + } + } else { + // Empty axes means reverse + for (int i = ndim - 1; i >= 0; --i) { + attr_axes.push_back(i); + } + } + } else if (auto attr = call->attrs.as()) { + Layout src_layout(attr->src_layout); + Layout dst_layout(attr->dst_layout); + for (int i = 0; i < ndim; ++i) { + attr_axes.push_back(src_layout.IndexOf(dst_layout[i])); + } + } else { + CHECK(false) << "Expected transpose or layout_transform, but got " + << Downcast(call->op)->name; + } + return std::move(attr_axes); +} + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. 
@@ -316,19 +348,7 @@ class SimplifyTranspose : public DFPatternRewrite { it++; } - // Check if the transpose is still required - bool need_transpose = false; - for (int i = 0; i < ndim; ++i) { - if (axes[i] != i) { - need_transpose = true; - break; - } - } - - if (need_transpose) { - return MakeTranspose(x, axes); - } - return x; + return MakeTranspose(x, axes); } String PermuteLayout(const String& layout, std::vector axes_order) const { @@ -431,32 +451,50 @@ class SimplifyTranspose : public DFPatternRewrite { return Downcast(output_layout_trans); } - std::vector GetTransposeAxisOrder(const Call& call, int ndim) const { - std::vector attr_axes; - if (auto attr = call->attrs.as()) { - if (attr->axes.defined()) { - for (int i = 0; i < ndim; ++i) { - int64_t axis = attr->axes[i].IntValue(); - axis += (axis < 0) ? ndim : 0; - attr_axes.push_back(axis); - } - } else { - // Empty axes means reverse - for (int i = ndim - 1; i >= 0; --i) { - attr_axes.push_back(i); - } + private: + /*! \brief Pattern input */ + DFPattern x_; +}; + +/*! + * \brief SimplifyNoOpTranspose matches a single transpose or + * layout_transform op whose axis order is the identity permutation and + * replaces the op with its input. 
+ */ +class SimplifyNoOpTranspose : public DFPatternRewrite { + public: + SimplifyNoOpTranspose() { + x_ = IsWildcard(); + auto trans1 = IsOp("transpose") || IsOp("layout_transform"); + pattern_ = trans1({x_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto x = node_map[x_][0]; + Call trans_call = Downcast(post); + + // Keep ops whose source and destination layouts differ (these may change layout or rank) + if (auto attr = trans_call->attrs.as()) { + if (attr->src_layout != attr->dst_layout) { + return post; } - } else if (auto attr = call->attrs.as()) { - Layout src_layout(attr->src_layout); - Layout dst_layout(attr->dst_layout); - for (int i = 0; i < ndim; ++i) { - attr_axes.push_back(src_layout.IndexOf(dst_layout[i])); + } + + int ndim = Downcast(pre->checked_type())->shape.size(); + auto axes = GetTransposeAxisOrder(trans_call, ndim); + + bool need_transpose = false; + for (int i = 0; i < ndim; ++i) { + if (axes[i] != i) { + need_transpose = true; + break; } - } else { - CHECK(false) << "Expected transpose or layout_transform, but got " - << Downcast(call->op)->name; } - return std::move(attr_axes); + + if (!need_transpose) return x; + + return post; } private: @@ -1037,6 +1075,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index d11242dbd86b..4edb85f2d793 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -266,6 +266,27 @@ def expected10(): y = relay.nn.relu(y) return relay.Function([x], y) + def before11(): + """ + Remove trivial no op transpose ops + + Input: + op1 -> relay.transpose(x, axes=[0, 1, 2, 3]) -> op2 + + Simplified: + op1 -> op2 + """ + x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32") 
+ y = relay.transpose(x, axes=[0, 1, 2, 3]) + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NCHW") + return relay.Function([x], y) + + def expected11(): + x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32") + y = relay.nn.relu(x) + return relay.Function([x], y) + for before, expected in [ [before1(), expected1()], [before2(), expected2()], @@ -277,6 +298,7 @@ def expected10(): [before8(), expected8()], [before9(), expected9()], [before10(), expected10()], + [before11(), expected11()], ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType())