From dad8cc767a12312a414b1b24cda76d51636a5493 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Nov 2020 06:01:02 -0800 Subject: [PATCH] [AutoScheduler] Add layout rewrite pass in relay --- include/tvm/ir/transform.h | 7 + include/tvm/relay/attrs/nn.h | 27 +++ include/tvm/relay/attrs/transform.h | 14 ++ include/tvm/relay/transform.h | 14 ++ include/tvm/topi/transform.h | 68 ++++++++ python/tvm/auto_scheduler/__init__.py | 2 +- python/tvm/auto_scheduler/compute_dag.py | 30 ++++ .../tvm/auto_scheduler/relay_integration.py | 44 ++++- python/tvm/relay/op/_transform.py | 2 + python/tvm/relay/op/strategy/generic.py | 2 + python/tvm/relay/op/strategy/x86.py | 1 - python/tvm/te/tensor.py | 5 +- python/tvm/topi/nn/conv2d.py | 53 +++++- src/auto_scheduler/compute_dag.cc | 20 ++- src/ir/transform.cc | 54 +++--- src/relay/analysis/type_solver.cc | 1 + src/relay/backend/build_module.cc | 17 ++ src/relay/backend/compile_engine.cc | 26 ++- src/relay/backend/compile_engine.h | 9 + src/relay/backend/utils.h | 9 + src/relay/op/make_op.h | 2 + src/relay/op/tensor/transform.cc | 50 +++++- .../auto_scheduler_layout_rewrite.cc | 164 ++++++++++++++++++ .../auto_scheduler_layout_rewrite.h | 49 ++++++ .../test_auto_scheduler_layout_rewrite.py | 116 +++++++++++++ 25 files changed, 721 insertions(+), 65 deletions(-) create mode 100644 src/relay/transforms/auto_scheduler_layout_rewrite.cc create mode 100644 src/relay/transforms/auto_scheduler_layout_rewrite.h create mode 100644 tests/python/relay/test_auto_scheduler_layout_rewrite.py diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index d2931123073bc..2c5b6293e6c9d 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -197,6 +197,13 @@ class PassContext : public ObjectRef { */ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + /*! + * \brief Check if a pass is enabled. + * \param info The pass information. + * \return true if the pass is enabled. Otherwise, false. + */ + TVM_DLL bool PassEnabled(const PassInfo& info) const; + /*! * \brief Register a valid configuration option and its ValueType for validation. * diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index e697ac45bd125..ceb87ff9f5885 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { tvm::String data_layout; tvm::String kernel_layout; tvm::String out_layout; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { @@ -170,6 +171,9 @@ struct Conv2DAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -212,6 +216,7 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { std::string data_layout; std::string kernel_layout; std::string out_layout; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { @@ -264,6 +269,9 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -302,6 +310,7 @@ struct Conv3DAttrs : public tvm::AttrsNode { std::string data_layout; std::string kernel_layout; std::string out_layout; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") { @@ -353,6 +362,9 @@ struct Conv3DAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -923,10 +935,14 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") { TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -935,6 +951,17 @@ struct DenseAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for batch matmul operator */ +struct BatchMatmulAttrs : public tvm::AttrsNode { + std::string auto_scheduler_rewritten_layout; + + TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchNormAttrs") { + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); + } +}; + /*! \brief Attributes for sparse_dense operator */ struct SparseDenseAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {} diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a7830cf616479..607fa2c80e7eb 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -350,6 +350,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for AutoSchedulerLayoutTransform operator */ +struct AutoSchedulerLayoutTransformAttrs + : public tvm::AttrsNode { + std::string src_layout; + std::string dst_layout; + + TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs, + "relay.attrs.AutoSchedulerLayoutTransformAttrs") { + TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)"); + TVM_ATTR_FIELD(dst_layout) + .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)"); + } +}; + /*! \brief Attributes for ShapeOf operator */ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a9a45b5f101a4..a2ffe5e853b3f 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -106,6 +106,14 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); +/*! + * \brief The inverse operation of FuseOps. It transforms a fused program returned by + * FuseOps into the program before FuseOps. (i.e., x == DefuseOps(FuseOps(x))) + * + * \return The pass. + */ +TVM_DLL Pass DefuseOps(); + /*! * \brief Rewrite the annotated program. * @@ -315,6 +323,12 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); +/*! + * \brief Do layout rewrite according to the tile structure created by auto-scheduler. + * \return The pass. + */ +TVM_DLL Pass AutoSchedulerLayoutRewrite(); + /*! * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index c866dfb7f86b1..d200b1679bab1 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } +/*! \brief utility function for auto_scheduler_layout_transform */ +inline void parse_auto_scheduler_layout(const String& layout, Array* shape, + std::vector* axes) { + int32_t factor = 0; + std::string axis = ""; + for (char c : std::string(layout)) { + if (c >= 'A' && c <= 'z') { + axis += c; + if (factor != 0) { + shape->push_back(factor); + factor = 0; + } + } else if (c >= '0' && c <= '9') { + factor = factor * 10 + c - '0'; + if (!axis.empty()) { + axes->push_back(axis); + axis = ""; + } + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + if (!axis.empty()) { + axes->push_back(axis); + } +} + +/*! + * \brief Transform the auto-scheduler generated layout according to + * \p src_layout and \p dst_layout + * \param src the source input. + * \param src_layout the source layout. + * \param dst_layout the destination layout. + * \param name output tensor name. + * \param tag output tensor tag. + * \return A tensor with shape in \p dst_layout + */ +inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout, + const String& dst_layout, + const String name = "T_auto_scheduler_layout_trans", + const String tag = kInjective) { + Array src_shape; + std::vector src_axes; + Array dst_shape; + std::vector dst_axes; + + parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes); + parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes); + return compute( + dst_shape, + [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices; + for (const std::string& src_axis : src_axes) { + PrimExpr src_index = 0; + CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); + for (size_t i = 0; i < dst_axes.size(); ++i) { + if (dst_axes[i] == src_axis) { + src_index = src_index * dst_shape[i] + dst_indices_expr[i]; + } + } + src_indices.push_back(src_index); + } + return src(src_indices); + }, + name, tag); +} + /*! * \brief Get the shape of input tensor. * \param src the input tensor. diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index f0d076e75f027..be720afc50fe4 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -32,7 +32,7 @@ # Shortcut from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule -from .compute_dag import ComputeDAG +from .compute_dag import ComputeDAG, rewrite_compute_body from .cost_model import RandomModel, XGBModel from .dispatcher import DispatchContext, ApplyHistoryBest from .measure import ( diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 3427709d819a1..806b37d697794 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -162,6 +162,22 @@ def infer_bound_from_state(self, state): updated_state.stage_id_map[k] = v return updated_state + def rewrite_layout_from_state(self, state): + """ + Rewrite the layout according to the transform steps in the history of a state + + Parameters + ---------- + state : Union[State, StateObject] + The state from which we get transform steps. + + Returns + ------- + updated_state : StateObject + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj) + def hash_key(self): """Return the hash key of this compute DAG. @@ -210,3 +226,17 @@ def __setstate__(self, state): self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche) + + +def rewrite_compute_body(compute, placeholder, new_layout): + """Rewrite the body of a ComputeOp according to a new layout of a placeholder""" + body = [] + for b in compute.op.body: + body.append(_ffi_api.RewriteIndexForNewLayout(placeholder.op, new_layout, b)) + op_node = tvm.te._ffi_api.ComputeOp( + compute.op.name, compute.op.tag, compute.op.attrs, compute.op.axis, body + ) + + num = op_node.num_outputs + outputs = tuple(op_node.output(i) for i in range(num)) + return outputs[0] if num == 1 else outputs diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 6864bcce66e38..f487add7d34bd 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -23,6 +23,7 @@ """ import logging +import json import threading import tvm @@ -46,7 +47,11 @@ def call_all_topi_funcs(mod, params, target): old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent autotvm.GLOBAL_SCOPE.silent = True - with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}): + with transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + disabled_pass={"AutoSchedulerLayoutRewrite"}, + ): opt_mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(opt_mod["main"]) @@ -158,6 +163,18 @@ def add_workload_key(self, workload_key, ccache_key): self.wkl_key_to_ccache_key[workload_key] = ccache_key +@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite") +def _enter_layout_rewrite(*args): + env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) + env.__enter__() + + +@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite") +def _exit_layout_rewrite(*args): + env = TracingEnvironment.current + env.__exit__(None, None, None) + + def traverse_to_get_io_tensors(outs): """Traverse from a list of output tensors to get both input and output tensors @@ -230,11 +247,13 @@ def auto_schedule_topi(outs, has_complex_op): key = register_workload_tensors(dag.hash_key(), io_tensors) # only enable layout rewrite for cpu backend - enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys + target = tvm.target.Target.current() + enable_layout_rewrite = "cpu" in target.keys env = TracingEnvironment.current - if env is None: # in the final build mode - state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag) + if env is None: + # in the final build mode + state = DispatchContext.current.query(target, key, has_complex_op, dag) if state is None: return None @@ -247,8 +266,21 @@ def auto_schedule_topi(outs, has_complex_op): env.add_workload_key(key, ccache_key) schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: - # todo(merrymercy, minminsun): port layout rewrite - raise NotImplementedError + # in prepare_layout_rewrite mode + if enable_layout_rewrite and has_layout_free: + dispatch_ctx = DispatchContext.current + state = dispatch_ctx.query(target, key, has_complex_op, dag) + if state is None: + return te.create_schedule([x.op for x in outs]) + + # rewrite the layout and update the context + # for the new dag + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((new_dag.hash_key(),)) + if new_key != key: + dispatch_ctx.update(target, new_key, state) + return te.create_schedule([x.op for x in outs]) else: raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 439d44b5790b9..de59a9c3657fa 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -79,6 +79,8 @@ def compute_strided_set(attrs, inputs, output_type): # layout_transform _reg.register_injective_schedule("layout_transform") _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) +_reg.register_injective_schedule("auto_scheduler_layout_transform") +_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e49135c4d1bf2..e8c3f41c7155d 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -179,6 +179,7 @@ def _compute_conv2d(attrs, inputs, out_type): data_layout = attrs.get_str("data_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype + auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout") out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilation] if has_groups: @@ -188,6 +189,7 @@ def _compute_conv2d(attrs, inputs, out_type): if need_out_layout: args.append(out_layout) args.append(out_dtype) + args.append(auto_scheduler_rewritten_layout) return [topi_compute(*args)] return _compute_conv2d diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3c5735b17aa5b..a705cc6963e81 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -117,7 +117,6 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" - logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6294eab2cad9b..ed260be189a39 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -59,8 +59,9 @@ class Tensor(DataProducer, _expr.ExprOp): def __call__(self, *indices): ndim = self.ndim - if len(indices) != ndim: - raise ValueError("Need to provide %d index in tensor slice" % ndim) + # TODO(merrymercy): tmp hack for layout rewrite + # if len(indices) != ndim: + # raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) args = [] for x in indices: diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 7c9cef6134391..2840ecabd40d6 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te +from tvm import te, auto_scheduler from .pad import pad from .utils import get_pad_tuple @@ -331,7 +331,15 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): return Output -def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): +def conv2d_nhwc( + Input, + Filter, + stride, + padding, + dilation, + out_dtype="float32", + auto_scheduler_rewritten_layout="", +): """Convolution operator in NHWC layout. Parameters @@ -371,8 +379,40 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): else: dilation_h, dilation_w = dilation + if auto_scheduler_rewritten_layout: + # infer shape for the rewritten layout + if len(Filter.shape) >= 10: + # For cpu tile structure SSRSRS + base = len(Filter.shape) - 10 + kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] + kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] + channel = Filter.shape[4 + base] * Filter.shape[8 + base] + num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] + for i in range(base + 2): + num_filter *= Filter.shape[i] + elif len(Filter.shape) == 6: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] + kernel_h = Filter.shape[2] + kernel_w = Filter.shape[3] + channel = Filter.shape[4] + elif len(Filter.shape) == 5: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[4] + kernel_h = Filter.shape[1] + kernel_w = Filter.shape[2] + channel = Filter.shape[3] + elif len(Filter.shape) == 4: + num_filter, kernel_h, kernel_w, channel = Filter.shape + else: + raise ValueError( + "Don't know how to infer the layout for filter shape: %s. " + "You can add a new branch for it to fix this." % str(Filter) + ) + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + batch, in_height, in_width, in_channel = Input.shape - kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -399,7 +439,14 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): ), name="Conv2dOutput", tag="conv2d_nhwc", + attrs={"layout_free_placeholders": [Filter]}, ) + + if auto_scheduler_rewritten_layout: + Output = auto_scheduler.rewrite_compute_body( + Output, Filter, auto_scheduler_rewritten_layout + ) + return Output diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e57fc8c9c2d9e..f17861d233778 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -42,6 +42,7 @@ #include #include "../arith/pattern_match.h" +#include "../relay/transforms/auto_scheduler_layout_rewrite.h" #include "search_policy/utils.h" #include "utils.h" @@ -813,8 +814,7 @@ std::string GetOrigLayout(std::set* placeholder_axis_names, const t ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size()); std::string orig_layout = os.str(); os.str(""); - // TODO(minmin): uncomment this line for relay integration - // ::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout); + ::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout); return orig_layout; } @@ -878,8 +878,7 @@ std::string GetNewLayout(const State& state, const int stage_id, const Stage& st } std::string new_layout = os.str(); os.str(""); - // TODO(minmin): uncomment this line for relay integration - // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout); + ::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout); return new_layout; } @@ -1425,5 +1424,18 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState") return dag.InferBound(state); }); +TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState") + .set_body_typed([](const ComputeDAG& dag, const State& state) { + Array* transform_steps = const_cast*>(&state->transform_steps); + return dag.RewriteLayout(transform_steps, LayoutRewriteOption::RewriteForPreTransformed); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout") + .set_body_typed([](const te::Operation& placeholder_op, const std::string& new_layout, + const PrimExpr& body) { + IndexRewriter index_rewriter(placeholder_op, new_layout); + return index_rewriter.Rewrite(body); + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 3b774462565e9..f4516d5e57c56 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -74,6 +74,26 @@ PassContext PassContext::Current() { } } +// linearly scan the pass array to match pass_name +bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { + for (auto x : pass_array) { + if (x == pass_name) return true; + } + return false; +} + +bool PassContext::PassEnabled(const PassInfo& info) const { + if (PassArrayContains(operator->()->disabled_pass, info->name)) { + return false; + } + + if (PassArrayContains(operator->()->required_pass, info->name)) { + return true; + } + + return operator->()->opt_level >= info->opt_level; +} + class PassConfigManager { public: void Register(std::string key, uint32_t value_type_index) { @@ -224,15 +244,6 @@ class SequentialNode : public PassNode { */ PassInfo Info() const override { return pass_info; } - /*! - * \brief Check if a pass is enabled. - * - * \param info The pass information. - * - * \return true if the pass is enabled. Otherwise, false. - */ - bool PassEnabled(const PassInfo& info) const; - /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -344,29 +355,6 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { << "\n"; } -// linearly scan the pass array to match pass_name -inline bool PassArrayContains(const Array& pass_array, - const std::string& pass_name) { - for (auto x : pass_array) { - if (x == pass_name) return true; - } - return false; -} - -bool SequentialNode::PassEnabled(const PassInfo& info) const { - PassContext ctx = PassContext::Current(); - - if (PassArrayContains(ctx->disabled_pass, info->name)) { - return false; - } - - if (PassArrayContains(ctx->required_pass, info->name)) { - return true; - } - - return ctx->opt_level >= info->opt_level; -} - Pass GetPass(const String& pass_name) { using tvm::runtime::Registry; const runtime::PackedFunc* f = nullptr; @@ -387,7 +375,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c for (const Pass& pass : passes) { ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!pass_ctx.PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 8f14b557dc54b..b4b5ec820b738 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -230,6 +230,7 @@ class TypeSolver::Unifier : public TypeFunctor { return Type(nullptr); } + tt1 = tt2; // TODO(merrymercy): tmp hack for layout rewrite in auto-scheduler. tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 82ac1c57018e8..a0828d1cac6cf 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -338,7 +338,24 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); + + // Do layout rewrite for auto-scheduler. + if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { + const auto& target = (*targets.begin()).second; + Pass major_pass = transform::AutoSchedulerLayoutRewrite(); + + if (target->kind->device_type == kDLCPU && pass_ctx.PassEnabled(major_pass->Info())) { + With tctx(target); + relay_module = major_pass(relay_module); + // Defuse ops to fold constants, then fuse them again + relay_module = transform::DefuseOps()(relay_module); + relay_module = transform::FoldConstant()(relay_module); + relay_module = transform::FuseOps()(relay_module); + } + } + relay_module = transform::InferType()(relay_module); + // Inline the functions that have been lifted by the module scope. // // TODO(@zhiics) Note that we need to be careful about the subgraphs with diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 1559d7edf35ff..98d9136629531 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -101,9 +101,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> explicit ScheduleGetter(Target target) : target_(target), device_copy_op_(Op::Get("device_copy")) { // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = transform::PassContext::Current() - ->GetConfig("relay.backend.use_auto_scheduler", Bool(false)) - .value(); + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); } CachedFunc Create(const Function& prim_func) { @@ -322,6 +320,17 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> const Op& device_copy_op_; }; +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc CreateSchedule(const Function& source_func, const Target& target) { + return ScheduleGetter(target).Create(source_func); +} + // Creates shape function from functor. class MakeShapeFunc : public backend::MemoizedExprTranslator> { public: @@ -680,17 +689,6 @@ class CompileEngineImpl : public CompileEngineNode { */ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - /*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ - CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); - } - private: // implement lowered func CCacheValue LowerInternal(const CCacheKey& key) { diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 55822917b6b71..d7628e7a5bdf1 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -241,6 +241,15 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc CreateSchedule(const Function& source_func, const Target& target); + /*! * \brief Check if the type is dynamic. * \param ty The type to be checked. diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4426642e8e18b..36016cf3658a3 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -295,6 +295,15 @@ inline std::string GetExtSymbol(const Function& func) { return std::string(name_node.value()); } +/*! + * \brief Return whether the auto scheduler is enabled in the pass context. + */ +inline bool IsAutoSchedulerEnabled() { + return transform::PassContext::Current() + ->GetConfig("relay.backend.use_auto_scheduler", Bool(false)) + .value(); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 34bff0f5b8582..d2fb6aa2b9c3c 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -52,6 +52,8 @@ Expr MakeFull(Expr fill_value, Array shape, DataType dtype); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); +Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout); + Expr MakeOnes(Array shape, DataType dtype); Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d1f2f267c5802..ee154100d78f4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2738,7 +2738,55 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); -/* relay._contrib_reverse_reshape */ +// relay.auto_scheduler_layout_transform +TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs); + +Array AutoSchedulerLayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::auto_scheduler_layout_transform(inputs[0], param->src_layout, param->dst_layout)}; +} + +bool AutoSchedulerLayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const auto* data = types[0].as(); + CHECK(data != nullptr); + const AutoSchedulerLayoutTransformAttrs* params = attrs.as(); + + Array dst_shape; + std::vector dst_axes; + + topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes); + + reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); + return true; +} + +Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout) { + auto attrs = make_object(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static const Op& op = Op::Get("auto_scheduler_layout_transform"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform") + .set_body_typed(MakeAutoSchedulerLayoutTransform); + +RELAY_REGISTER_OP("auto_scheduler_layout_transform") + .describe(R"code(Transform the input kernel layout. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("auto_scheduler_layout_transform", AutoSchedulerLayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", AutoSchedulerLayoutTransformCompute); + +// relay._contrib_reverse_reshape Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc new file mode 100644 index 0000000000000..17f6ecd93a40a --- /dev/null +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler_layout_rewrite.h + * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in + * conv2d and dense layers) according to the tile structure generated by the auto-scheduler. + */ + +#include "auto_scheduler_layout_rewrite.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "../backend/compile_engine.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +// Two global variables for receiving layout information from python +std::deque AutoSchedulerLayoutRewriter::global_ori_layouts_queue; +std::deque AutoSchedulerLayoutRewriter::global_new_layouts_queue; + +// Copy an Attrs but with a new auto_scheduler_rewritten_layout filed. +template +Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) { + auto n = make_object(*ptr); + n->auto_scheduler_rewritten_layout = layout; + return Attrs(n); +} + +// Mutate ops in a function +class FuncMutator : public ExprMutator { + public: + FuncMutator(const std::deque& ori_layouts_queue, + const std::deque& new_layouts_queue) + : ExprMutator(), + ori_layouts_queue_(ori_layouts_queue), + new_layouts_queue_(new_layouts_queue) {} + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as(); + if (call && call->op.as() && + (std::find(target_ops_.begin(), target_ops_.end(), n->op.as()->name) != + target_ops_.end()) && + !ori_layouts_queue_.empty() && !new_layouts_queue_.empty()) { + // Pop a new layout from the queue + const std::string ori_layout = ori_layouts_queue_.front(); + const std::string new_layout = new_layouts_queue_.front(); + ori_layouts_queue_.pop_front(); + new_layouts_queue_.pop_front(); + + // Insert a new op to do layout transform. (This will be simplified by FoldConstant later). + Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1], ori_layout, new_layout); + Array updated_args = {call->args[0], updated_kernel}; + + // Update the attrs + Attrs updated_attrs; + if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } + new_n = Call(call->op, updated_args, updated_attrs); + } + return new_n; + } + + private: + std::deque ori_layouts_queue_; + std::deque new_layouts_queue_; + + std::vector target_ops_{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d", "nn.dense", "nn.batch_matmul"}; +}; + +Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + const auto* call = new_n.as(); + if (call) { + const auto* func = call->op.as(); + if (func) { + global_ori_layouts_queue.clear(); + global_new_layouts_queue.clear(); + + // Use ScheduleGetter to call python lower functions + // This is used to get the layout transform information + // The layout transformation will be recorded to global_ori_layout_queue + // in ComputeDAG::RewriteLayout + auto f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite"); + CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; + (*f)(); + + CreateSchedule(GetRef(func), Target::Current()); + + f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); + CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; + (*f)(); + + // Mutate the called function + if (!global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { + auto ret = FuncMutator(global_ori_layouts_queue, global_new_layouts_queue).VisitExpr(new_n); + return ret; + } + } + } + + return new_n; +} + +Expr AutoSchedulerLayoutRewrite(const Expr& expr) { + // Do a post-order DSF to mutate the layout of + // all "layout free" placeholders for the auto-scheduler. + return AutoSchedulerLayoutRewriter().Mutate(expr); +} + +namespace transform { + +Pass AutoSchedulerLayoutRewrite() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::AutoSchedulerLayoutRewrite(f)); + }; + return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite") + .set_body_typed(AutoSchedulerLayoutRewrite); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.h b/src/relay/transforms/auto_scheduler_layout_rewrite.h new file mode 100644 index 0000000000000..d8d4b1ec5302a --- /dev/null +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler_layout_rewrite.h + * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in + * conv2d and dense layers) according to the tile structure generated by the auto-scheduler. + */ + +#ifndef TVM_RELAY_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ +#define TVM_RELAY_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ + +#include + +#include +#include + +namespace tvm { +namespace relay { + +class AutoSchedulerLayoutRewriter : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) final; + + // Two global variables for receiving layout information from python + static std::deque global_ori_layouts_queue; + static std::deque global_new_layouts_queue; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py new file mode 100644 index 0000000000000..89366a0b9fdab --- /dev/null +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test layout rewrite support for whole neural networks""" +import numpy as np + +import tvm +from tvm import relay, auto_scheduler +from tvm.contrib import graph_runtime +import tvm.testing + + +def get_np_array(var, dtype): + return np.random.randn(*[int(x) for x in var.type_annotation.shape]).astype(dtype) + + +def get_relay_conv2d( + outc=128, + inc=64, + height=14, + width=14, + kh=3, + kw=3, + batch=1, + pad=0, + stride=1, + dilation=1, + layout="NHWC", +): + dtype = "float32" + if layout == "NHWC": + kernel_layout = "HWIO" + d = relay.var("data", shape=(batch, height, width, inc), dtype=dtype) + w = relay.var("weight", shape=(kh, kw, inc, outc), dtype=dtype) + elif layout == "NCHW": + kernel_layout = "OIHW" + d = relay.var("data", shape=(batch, inc, height, width), dtype=dtype) + w = relay.var("weight", shape=(outc, inc, kh, kw), dtype=dtype) + + y = relay.nn.conv2d( + d, + w, + padding=pad, + kernel_size=(kh, kw), + strides=(stride, stride), + dilation=(dilation, dilation), + channels=outc, + groups=1, + data_layout=layout, + kernel_layout=kernel_layout, + ) + mod = tvm.IRModule() + mod["main"] = relay.Function([d, w], y) + data, weight = get_np_array(d, dtype), get_np_array(w, dtype) + return mod, data, weight + + +def tune_and_check(mod, data, weight): + # Extract tasks from a relay program + target = tvm.target.Target("llvm") + tasks, task_weights = auto_scheduler.extract_tasks(mod, target=target, params={}) + + # Tune workloads + log_file = "test_layout_rewrite.json" + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=1, + num_measures_per_round=1, + builder=auto_scheduler.LocalBuilder(timeout=60), + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + ) + tuner.tune(tune_option, search_policy="sketch.random") + + # Compile and run + def compile_and_run(disabled_pass={}): + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + disabled_pass=disabled_pass, + ): + lib = relay.build(mod, target=target, params={"weight": weight}) + + ctx = tvm.cpu() + module = graph_runtime.GraphModule(lib["default"](ctx)) + module.set_input("data", data) + module.run() + + return module.get_output(0).asnumpy() + + # Check correctness + actual_output = compile_and_run() + expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"}) + tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-3) + + +def test_conv2d(): + mod, data, weight = get_relay_conv2d(kh=1, kw=1) + tune_and_check(mod, data, weight) + + +if __name__ == "__main__": + test_conv2d()