From 396a09e06441024f5b95dcf6762745368cf9d8e6 Mon Sep 17 00:00:00 2001
From: Chris Sullivan
Date: Tue, 4 May 2021 01:05:41 -0700
Subject: [PATCH] [Relay][Pass] Update SimplifyTranspose to correctly simplify
 rank changing layout transforms (#7807)

---
 src/relay/transforms/simplify_expr.cc         | 175 ++++++++++++++----
 tests/python/relay/test_pass_simplify_expr.py | 166 +++++++++++++++++
 2 files changed, 310 insertions(+), 31 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 5662ef5b45a6..fb7a76f1ea7a 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -31,6 +31,8 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
+#include <memory>
+#include <string>
 #include <utility>
 
 #include "../op/tensor/transform.h"
@@ -117,36 +119,20 @@ class SimplifyTranspose : public DFPatternRewrite {
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
-    // Helper function to get the axes from call node attribute
-    auto get_axes_from_call = [](const Call trans_call, int ndim) {
-      std::vector<int> attr_axes;
-      if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
-        if (attr->axes.defined()) {
-          for (int i = 0; i < ndim; ++i) {
-            int64_t axis = attr->axes[i];
-            axis += (axis < 0) ? ndim : 0;
-            attr_axes.push_back(axis);
-          }
-        } else {
-          // Empty axes means reverse
-          for (int i = ndim - 1; i >= 0; --i) {
-            attr_axes.push_back(i);
-          }
-        }
-      } else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
-        Layout src_layout(attr->src_layout);
-        Layout dst_layout(attr->dst_layout);
-        for (int i = 0; i < ndim; ++i) {
-          attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+    auto x = node_map[x_][0];
+
+    Call trans_call = Downcast<Call>(post);
+
+    // Try to fuse any rank changing layout transformations
+    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
+      if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
+        // Prune any trivial layout transformation
+        if (attr->src_layout == attr->dst_layout) {
+          return x;
         }
-      } else {
-        CHECK(false) << "Expected transpose or layout_transform, but got "
-                     << Downcast<Op>(trans_call->op)->name;
       }
-      return std::move(attr_axes);
-    };
-
-    auto x = node_map[x_][0];
+      return layout_trans.value();
+    }
 
     // Initialize axes
     int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
@@ -157,10 +143,9 @@ class SimplifyTranspose : public DFPatternRewrite {
 
     // Collect axes changes from the matched pattern, including two consecutive transposes.
     std::vector<std::vector<int>> interm_axes;
-    Call trans_call = Downcast<Call>(post);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
     trans_call = Downcast<Call>(trans_call->args[0]);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
 
     // Calculate the final axes in reverse order (from root to output)
    auto it = interm_axes.rbegin();
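
The axis composition the rewritten Callback performs over interm_axes can be
sketched in a few lines of plain Python (the helper below is illustrative,
not part of the patch): applying transpose `a` and then transpose `b` is
equivalent to one transpose with axes c[i] = a[b[i]].

def compose_transpose_axes(a, b):
    # transpose(transpose(x, a), b) == transpose(x, [a[i] for i in b])
    return [a[i] for i in b]

# NHWC -> NCHW ([0, 3, 1, 2]) followed by NCHW -> NHWC ([0, 2, 3, 1])
# composes to the identity order, which the pass then elides entirely.
assert compose_transpose_axes([0, 3, 1, 2], [0, 2, 3, 1]) == [0, 1, 2, 3]
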
@@ -190,6 +175,134 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    ICHECK_EQ(axes_order.size(), layout.size())
+        << "Number of axes must match the number of named axes in the layout to permute: length("
+        << old_layout << ") != " << axes_order.size();
+    std::stringstream order;
+    for (auto axis : axes_order) {
+      new_layout += old_layout[axis];
+      order << axis << ", ";
+    }
+    DLOG(INFO) << "Using transpose axes order {" << order.str()
+               << "} to permute layout: " << old_layout << " to " << new_layout;
+    return new_layout;
+  }
+
+  struct RankChangingLayoutDescriptor {
+    Layout src_layout;
+    Layout dst_layout;
+    // Either a rank changing layout transform or a transpose
+    Call other_transform;
+  };
+
+  std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
+    std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
+    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        desc = std::make_unique<RankChangingLayoutDescriptor>();
+        desc->src_layout = Layout(attr->src_layout);
+        desc->dst_layout = Layout(attr->dst_layout);
+        desc->other_transform = Downcast<Call>(call->args[0]);
+      }
+    }
+    if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        if (!desc) {
+          desc = std::make_unique<RankChangingLayoutDescriptor>();
+          desc->src_layout = Layout(attr->src_layout);
+          desc->dst_layout = Layout(attr->dst_layout);
+          desc->other_transform = call;
+        } else {
+          ICHECK(desc->src_layout->name == attr->dst_layout)
+              << "Back-to-back layout transforms must have the same intermediate layout: "
+              << desc->src_layout->name << " != " << attr->dst_layout;
+          desc->src_layout = Layout(attr->src_layout);
+        }
+      }
+    }
+    return desc;
+  }
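
PermuteLayout and the inverse axis order computed in FoldRankChangingLayoutTrans
below reduce to a short plain-Python sketch (helper names are illustrative,
not part of the patch):

def permute_layout(layout, axes_order):
    # PermuteLayout: reorder the named axes of a layout string
    return "".join(layout[axis] for axis in axes_order)

def inverse_axes(axes):
    # reverse permutation: inverse[axes[i]] = i
    inverse = [0] * len(axes)
    for i, axis in enumerate(axes):
        inverse[axis] = i
    return inverse

# A transpose with axes [0, 3, 1, 2] maps NHWC to NCHW; applying the inverse
# order to the layout_transform's "NCHW" source layout recovers "NHWC".
assert inverse_axes([0, 3, 1, 2]) == [0, 2, 3, 1]
assert permute_layout("NCHW", [0, 2, 3, 1]) == "NHWC"
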
+
+  /*!
+   * \brief Fuse call and its argument into a single layout_transform operator
+   * when either call or its argument is a rank changing layout_transform, e.g.,
+   *
+   * Simplify
+   *
+   * [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
+   *
+   * to,
+   *
+   * [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
+   *
+   * \param data The input expression to the matched pattern
+   * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
+   */
+  Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    // Check to see if either the first or second call in the matched pattern
+    // is a rank changing layout transform. If so, return a descriptor containing
+    // the layouts and any additional transpose or layout transform op.
+    auto desc = GetRankChangeDescriptor(call);
+    if (desc == nullptr) {
+      // No rank changing layout transform
+      return Optional<Call>{nullptr};
+    }
+
+    Optional<Expr> output_layout_trans;
+    // Fuse a rank increasing layout transform and a preceding transpose
+    if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
+      // Calculate the reverse axis order and apply to the source layout
+      std::vector<int> inverse(axes.size());
+      for (size_t i = 0; i < axes.size(); i++) {
+        inverse[axes[i]] = i;
+      }
+      String new_layout = PermuteLayout(desc->src_layout->name, inverse);
+      output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
+      // Fuse a rank decreasing layout transform followed by a transpose
+    } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
+      String new_layout = PermuteLayout(desc->dst_layout->name, axes);
+      output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
+      // Fuse two back-to-back layout transformations which change rank
+    } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
+      output_layout_trans =
+          MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
+    }
+    return Downcast<Call>(output_layout_trans);
+  }
+
+  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
+    std::vector<int> attr_axes;
+    if (auto attr = call->attrs.as<TransposeAttrs>()) {
+      if (attr->axes.defined()) {
+        for (int i = 0; i < ndim; ++i) {
+          int64_t axis = attr->axes[i];
+          axis += (axis < 0) ? ndim : 0;
+          attr_axes.push_back(axis);
+        }
+      } else {
+        // Empty axes means reverse
+        for (int i = ndim - 1; i >= 0; --i) {
+          attr_axes.push_back(i);
+        }
+      }
+    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      for (int i = 0; i < ndim; ++i) {
+        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+      }
+    } else {
+      CHECK(false) << "Expected transpose or layout_transform, but got "
+                   << Downcast<Op>(call->op)->name;
+    }
+    return std::move(attr_axes);
+  }
+
  private:
   /*! \brief Pattern input */
   DFPattern x_;
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d1dffa34578b..9f11d3827064 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -106,10 +106,176 @@ def expected3():
         y = relay.transpose(y, axes=[0, 2, 3, 1])
         return relay.Function([x], y)
 
+    # Test a series of transpose and rank changing layout_transform ops
+    def before4():
+        """
+        Simplify transpose->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        """
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
+        y = relay.transpose(x, axes=[0, 3, 1, 2])
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.transpose(y, axes=[0, 2, 3, 1])
+        return relay.Function([x], y)
+
+    def expected4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
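
Outside the test harness, the before4/expected4 rewrite can be observed
directly with a short script (a sketch assuming a local TVM build; it
mirrors what run_opt_pass does):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
y = relay.transpose(x, axes=[0, 3, 1, 2])                    # NHWC -> NCHW
y = relay.layout_transform(y, "NCHW", "NCHW4c")              # NCHW -> NCHW4c
mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyExpr()(mod)
# Expect a single layout_transform with src_layout="NHWC", dst_layout="NCHW4c"
print(mod)
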
+
+    def before5():
+        """
+        Simplify layout_transform->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        """
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def expected5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before6():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NHWC -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected6():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before7():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW4c -> NCHW8c -> NCHW4c -> op
+
+        Simplified:
+        NCHW4c -> op
+        """
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before8():
+        """
+        Simplify layout_transform->layout_transform with rank contraction and expansion.
+
+        Input:
+        NCHW4c -> NCHW -> NCHW8c -> op
+
+        Simplified:
+        NCHW4c -> NCHW8c -> op
+        """
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW")
+        y = relay.layout_transform(y, "NCHW", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def before9():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NCHW4c -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected9():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
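
The trivial-pruning cases above (before6, before7, before9) and the
contraction/expansion case (before8) follow one rule, sketched here in plain
Python (hypothetical helper, not part of the patch): back-to-back
layout_transforms fuse when their intermediate layouts agree, and the fused
op is dropped entirely when source equals destination.

def fuse_layout_transforms(first, second):
    src1, dst1 = first
    src2, dst2 = second
    assert dst1 == src2, "intermediate layouts must match"
    # None means the composite is the identity and is pruned
    return None if src1 == dst2 else (src1, dst2)

assert fuse_layout_transforms(("NCHW4c", "NCHW"), ("NCHW", "NCHW8c")) == ("NCHW4c", "NCHW8c")  # before8
assert fuse_layout_transforms(("NCHW", "NCHW4c"), ("NCHW4c", "NCHW")) is None  # before9: pruned
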
+
+    def before10():
+        """
+        Simplify two back-to-back layout_transforms with no rank change
+        into a single transpose.
+
+        Input:
+        NCHW -> NHWC -> CHWN -> op
+
+        Simplified:
+        NCHW -> CHWN -> op
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "CHWN")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected10():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.transpose(x, axes=[1, 2, 3, 0])
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
         [before3(), expected3()],
+        [before4(), expected4()],
+        [before5(), expected5()],
+        [before6(), expected6()],
+        [before7(), expected7()],
+        [before8(), expected8()],
+        [before9(), expected9()],
+        [before10(), expected10()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())
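
For before10, the expected transpose axes fall out of GetTransposeAxisOrder's
layout handling, sketched in plain Python (illustrative, not part of the
patch): each same-rank layout_transform reads as a transpose with
axes[i] = src.index(dst[i]), and the two orders compose as in the Callback.

def layout_to_axes(src, dst):
    return [src.index(c) for c in dst]

a = layout_to_axes("NCHW", "NHWC")  # [0, 2, 3, 1]
b = layout_to_axes("NHWC", "CHWN")  # [3, 1, 2, 0]
# Composing the two orders gives the single transpose in expected10.
assert [a[i] for i in b] == [1, 2, 3, 0]
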