From 54b9e51fbb19ff72258ed9a01d01f6d78166e32d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Nov 2020 08:11:35 -0800 Subject: [PATCH 01/12] [AutoScheduler] Add layout rewrite pass in relay --- include/tvm/ir/transform.h | 7 + include/tvm/relay/attrs/nn.h | 8 + include/tvm/relay/attrs/transform.h | 14 ++ include/tvm/relay/transform.h | 14 ++ include/tvm/topi/transform.h | 68 ++++++++ python/tvm/auto_scheduler/__init__.py | 2 +- python/tvm/auto_scheduler/compute_dag.py | 17 ++ .../tvm/auto_scheduler/relay_integration.py | 103 +++++++++++- python/tvm/relay/op/_transform.py | 2 + python/tvm/relay/op/strategy/generic.py | 2 + python/tvm/relay/op/strategy/x86.py | 1 - python/tvm/te/tensor.py | 2 +- python/tvm/topi/nn/conv2d.py | 40 ++++- src/auto_scheduler/compute_dag.cc | 20 ++- src/ir/transform.cc | 54 +++--- src/relay/backend/build_module.cc | 17 ++ src/relay/backend/compile_engine.cc | 26 ++- src/relay/backend/compile_engine.h | 9 + src/relay/backend/utils.h | 9 + src/relay/op/make_op.h | 2 + src/relay/op/nn/convolution.h | 7 + src/relay/op/tensor/transform.cc | 50 +++++- .../auto_scheduler_layout_rewrite.cc | 155 ++++++++++++++++++ .../auto_scheduler_layout_rewrite.h | 49 ++++++ .../test_auto_scheduler_layout_rewrite.py | 122 ++++++++++++++ 25 files changed, 735 insertions(+), 65 deletions(-) create mode 100644 src/relay/transforms/auto_scheduler_layout_rewrite.cc create mode 100644 src/relay/transforms/auto_scheduler_layout_rewrite.h create mode 100644 tests/python/relay/test_auto_scheduler_layout_rewrite.py diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index d2931123073b..2c5b6293e6c9 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -197,6 +197,13 @@ class PassContext : public ObjectRef { */ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + /*! + * \brief Check if a pass is enabled. + * \param info The pass information. + * \return true if the pass is enabled. Otherwise, false. + */ + TVM_DLL bool PassEnabled(const PassInfo& info) const; + /*! * \brief Register a valid configuration option and its ValueType for validation. * diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index e697ac45bd12..278e98097e84 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode { tvm::String data_layout; tvm::String kernel_layout; tvm::String out_layout; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { @@ -170,6 +171,9 @@ struct Conv2DAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -212,6 +216,7 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { std::string data_layout; std::string kernel_layout; std::string out_layout; + std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { @@ -264,6 +269,9 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." 
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) + .set_default("") + .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a7830cf61647..607fa2c80e7e 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -350,6 +350,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for AutoSchedulerLayoutTransform operator */ +struct AutoSchedulerLayoutTransformAttrs + : public tvm::AttrsNode { + std::string src_layout; + std::string dst_layout; + + TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs, + "relay.attrs.AutoSchedulerLayoutTransformAttrs") { + TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)"); + TVM_ATTR_FIELD(dst_layout) + .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)"); + } +}; + /*! \brief Attributes for ShapeOf operator */ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a9a45b5f101a..a2ffe5e853b3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -106,6 +106,14 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); +/*! + * \brief The inverse operation of FuseOps. It transforms a fused program returned by + * FuseOps into the program before FuseOps. (i.e., x == DefuseOps(FuseOps(x))) + * + * \return The pass. + */ +TVM_DLL Pass DefuseOps(); + /*! * \brief Rewrite the annotated program. * @@ -315,6 +323,12 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); +/*! + * \brief Do layout rewrite according to the tile structure created by auto-scheduler. + * \return The pass. + */ +TVM_DLL Pass AutoSchedulerLayoutRewrite(); + /*! * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index c866dfb7f86b..d200b1679bab 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } +/*! \brief utility function for auto_scheduler_layout_transform */ +inline void parse_auto_scheduler_layout(const String& layout, Array* shape, + std::vector* axes) { + int32_t factor = 0; + std::string axis = ""; + for (char c : std::string(layout)) { + if (c >= 'A' && c <= 'z') { + axis += c; + if (factor != 0) { + shape->push_back(factor); + factor = 0; + } + } else if (c >= '0' && c <= '9') { + factor = factor * 10 + c - '0'; + if (!axis.empty()) { + axes->push_back(axis); + axis = ""; + } + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + if (!axis.empty()) { + axes->push_back(axis); + } +} + +/*! + * \brief Transform the auto-scheduler generated layout according to + * \p src_layout and \p dst_layout + * \param src the source input. + * \param src_layout the source layout. + * \param dst_layout the destination layout. + * \param name output tensor name. + * \param tag output tensor tag. 
+ * \return A tensor with shape in \p dst_layout + */ +inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout, + const String& dst_layout, + const String name = "T_auto_scheduler_layout_trans", + const String tag = kInjective) { + Array src_shape; + std::vector src_axes; + Array dst_shape; + std::vector dst_axes; + + parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes); + parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes); + return compute( + dst_shape, + [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices; + for (const std::string& src_axis : src_axes) { + PrimExpr src_index = 0; + CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); + for (size_t i = 0; i < dst_axes.size(); ++i) { + if (dst_axes[i] == src_axis) { + src_index = src_index * dst_shape[i] + dst_indices_expr[i]; + } + } + src_indices.push_back(src_index); + } + return src(src_indices); + }, + name, tag); +} + /*! * \brief Get the shape of input tensor. * \param src the input tensor. diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index f0d076e75f02..5bf2335ec7cf 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -44,7 +44,7 @@ LocalRPCMeasureContext, ) from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records -from .relay_integration import extract_tasks +from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body from .search_task import SearchTask from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates from .task_scheduler import TaskScheduler diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 3427709d819a..c1a195f3c8fe 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -162,6 +162,23 @@ def infer_bound_from_state(self, state): updated_state.stage_id_map[k] = v return updated_state + def rewrite_layout_from_state(self, state): + """ + Rewrite the layout according to the transform steps in the history of a state + + Parameters + ---------- + state : Union[State, StateObject] + The state from which we get transform steps. + + Returns + ------- + updated_dag : ComputeDAG + The compute dag with rewritten layout. + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj) + def hash_key(self): """Return the hash key of this compute DAG. diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 6864bcce66e3..9cd7b43d8067 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -23,11 +23,15 @@ """ import logging +import json import threading import tvm from tvm import autotvm, te, transform -from tvm.te.tensor import ComputeOp, PlaceholderOp +from tvm.runtime import convert_to_object +from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor +from tvm.tir import expr as _expr +from . 
import _ffi_api from .compute_dag import ComputeDAG from .dispatcher import DispatchContext from .search_task import SearchTask @@ -46,7 +50,11 @@ def call_all_topi_funcs(mod, params, target): old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent autotvm.GLOBAL_SCOPE.silent = True - with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}): + with transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + disabled_pass={"AutoSchedulerLayoutRewrite"}, + ): opt_mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(opt_mod["main"]) @@ -158,6 +166,20 @@ def add_workload_key(self, workload_key, ccache_key): self.wkl_key_to_ccache_key[workload_key] = ccache_key +@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite") +def enter_layout_rewrite(): + """Enter layout rewrite tracing environment""" + env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) + env.__enter__() + + +@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite") +def exit_layout_rewrite(): + """Exit layout rewrite tracing environment""" + env = TracingEnvironment.current + env.__exit__(None, None, None) + + def traverse_to_get_io_tensors(outs): """Traverse from a list of output tensors to get both input and output tensors @@ -230,11 +252,13 @@ def auto_schedule_topi(outs, has_complex_op): key = register_workload_tensors(dag.hash_key(), io_tensors) # only enable layout rewrite for cpu backend - enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys + target = tvm.target.Target.current() + enable_layout_rewrite = "cpu" in target.keys env = TracingEnvironment.current - if env is None: # in the final build mode - state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag) + if env is None: + # in the final build mode + state = DispatchContext.current.query(target, key, has_complex_op, dag) if state is None: return None @@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op): env.add_workload_key(key, ccache_key) schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: - # todo(merrymercy, minminsun): port layout rewrite - raise NotImplementedError + # in prepare_layout_rewrite mode + if enable_layout_rewrite and has_layout_free: + dispatch_ctx = DispatchContext.current + state = dispatch_ctx.query(target, key, has_complex_op, dag) + if state is None: + return te.create_schedule([x.op for x in outs]) + + # rewrite the layout and update the context for the new dag + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((new_dag.hash_key(),)) + if new_key != key: + dispatch_ctx.update(target, new_key, state) + return te.create_schedule([x.op for x in outs]) else: raise ValueError("Invalid tracing mode: " + env.tracing_mode) return schedule + + +def tensor_no_check_call(self, *indices): + """An indexing function without any check. + This is the same as `tvm.te.Tensor::__call__` except that the safety + check is removed. + """ + indices = convert_to_object(indices) + args = [] + for x in indices: + if isinstance(x, _expr.PrimExpr): + args.append(x) + elif isinstance(x, _expr.IterVar): + args.append(x.var) + else: + raise ValueError("The indices must be expression") + + return _expr.ProducerLoad(self, args) + + +def remove_index_check(tensor): + """Remove the safety check in the indexing function for a tensor. 
+ This is done by monkey patching its indexing function. + After removing the check, we are allowed to create a + temporary wrong IR and fix it later in other places. + + Parameters + ---------- + tensor: Tensor + The tensor to remove index check. + """ + # Monkey patch the indexing function + tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor) + + +def rewrite_compute_body(compute_tensor, new_layout): + """Rewrite the body of a ComputeOp according to a new layout of a placeholder""" + op = compute_tensor.op + + # Get layout free placeholders + layout_free_placeholders = op.attrs["layout_free_placeholders"] + assert len(layout_free_placeholders) == 1 + placeholder_op = layout_free_placeholders[0].op + + # Rewrite the index expression in body + body = [] + for b in op.body: + body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op, new_layout, b)) + op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, body) + + num = op_node.num_outputs + outputs = tuple(op_node.output(i) for i in range(num)) + return outputs[0] if num == 1 else outputs diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 439d44b5790b..de59a9c3657f 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -79,6 +79,8 @@ def compute_strided_set(attrs, inputs, output_type): # layout_transform _reg.register_injective_schedule("layout_transform") _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) +_reg.register_injective_schedule("auto_scheduler_layout_transform") +_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e49135c4d1bf..e8c3f41c7155 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -179,6 +179,7 @@ def _compute_conv2d(attrs, inputs, out_type): data_layout = attrs.get_str("data_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype + auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout") out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilation] if has_groups: @@ -188,6 +189,7 @@ def _compute_conv2d(attrs, inputs, out_type): if need_out_layout: args.append(out_layout) args.append(out_dtype) + args.append(auto_scheduler_rewritten_layout) return [topi_compute(*args)] return _compute_conv2d diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3c5735b17aa5..a705cc6963e8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -117,7 +117,6 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" - logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6294eab2cad9..bdf39544759b 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -40,7 +40,7 @@ def __getitem__(self, indices): def asobject(self): """Convert slice to object.""" - return self.tensor(*self.indices) + return self.tensor.__call__(*self.indices) @property def dtype(self): diff --git 
a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 7c9cef613439..492b62b3e21d 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te +from tvm import te, auto_scheduler from .pad import pad from .utils import get_pad_tuple @@ -331,7 +331,15 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): return Output -def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): +def conv2d_nhwc( + Input, + Filter, + stride, + padding, + dilation, + out_dtype="float32", + auto_scheduler_rewritten_layout="", +): """Convolution operator in NHWC layout. Parameters @@ -371,8 +379,29 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): else: dilation_h, dilation_w = dilation + if auto_scheduler_rewritten_layout: + # Infer shape for the rewritten layout + if len(Filter.shape) >= 10: + # For cpu tile structure SSRSRS + base = len(Filter.shape) - 10 + kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] + kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] + channel = Filter.shape[4 + base] * Filter.shape[8 + base] + num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] + for i in range(base + 2): + num_filter *= Filter.shape[i] + elif len(Filter.shape) == 4: + num_filter, kernel_h, kernel_w, channel = Filter.shape + else: + raise ValueError( + "Don't know how to infer the layout for filter shape: %s. " + "Please add a new branch to handle this case." % str(Filter) + ) + auto_scheduler.remove_index_check(Filter) + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + batch, in_height, in_width, in_channel = Input.shape - kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -399,7 +428,12 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): ), name="Conv2dOutput", tag="conv2d_nhwc", + attrs={"layout_free_placeholders": [Filter]}, ) + + if auto_scheduler_rewritten_layout: + Output = auto_scheduler.rewrite_compute_body(Output, auto_scheduler_rewritten_layout) + return Output diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e57fc8c9c2d9..f17861d23377 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -42,6 +42,7 @@ #include #include "../arith/pattern_match.h" +#include "../relay/transforms/auto_scheduler_layout_rewrite.h" #include "search_policy/utils.h" #include "utils.h" @@ -813,8 +814,7 @@ std::string GetOrigLayout(std::set* placeholder_axis_names, const t ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size()); std::string orig_layout = os.str(); os.str(""); - // TODO(minmin): uncomment this line for relay integration - // ::tvm::relay::KernelLayoutTransformer::global_orig_layouts_queue.push_back(orig_layout); + ::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout); return orig_layout; } @@ -878,8 +878,7 @@ std::string GetNewLayout(const State& state, const int stage_id, const Stage& st } std::string new_layout = os.str(); os.str(""); - // TODO(minmin): uncomment this line for relay integration - // ::tvm::relay::KernelLayoutTransformer::global_new_layouts_queue.push_back(new_layout); + 
::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout); return new_layout; } @@ -1425,5 +1424,18 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState") return dag.InferBound(state); }); +TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState") + .set_body_typed([](const ComputeDAG& dag, const State& state) { + Array* transform_steps = const_cast*>(&state->transform_steps); + return dag.RewriteLayout(transform_steps, LayoutRewriteOption::RewriteForPreTransformed); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout") + .set_body_typed([](const te::Operation& placeholder_op, const std::string& new_layout, + const PrimExpr& body) { + IndexRewriter index_rewriter(placeholder_op, new_layout); + return index_rewriter.Rewrite(body); + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 3b774462565e..f4516d5e57c5 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -74,6 +74,26 @@ PassContext PassContext::Current() { } } +// linearly scan the pass array to match pass_name +bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { + for (auto x : pass_array) { + if (x == pass_name) return true; + } + return false; +} + +bool PassContext::PassEnabled(const PassInfo& info) const { + if (PassArrayContains(operator->()->disabled_pass, info->name)) { + return false; + } + + if (PassArrayContains(operator->()->required_pass, info->name)) { + return true; + } + + return operator->()->opt_level >= info->opt_level; +} + class PassConfigManager { public: void Register(std::string key, uint32_t value_type_index) { @@ -224,15 +244,6 @@ class SequentialNode : public PassNode { */ PassInfo Info() const override { return pass_info; } - /*! - * \brief Check if a pass is enabled. - * - * \param info The pass information. - * - * \return true if the pass is enabled. Otherwise, false. - */ - bool PassEnabled(const PassInfo& info) const; - /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. 
@@ -344,29 +355,6 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { << "\n"; } -// linearly scan the pass array to match pass_name -inline bool PassArrayContains(const Array& pass_array, - const std::string& pass_name) { - for (auto x : pass_array) { - if (x == pass_name) return true; - } - return false; -} - -bool SequentialNode::PassEnabled(const PassInfo& info) const { - PassContext ctx = PassContext::Current(); - - if (PassArrayContains(ctx->disabled_pass, info->name)) { - return false; - } - - if (PassArrayContains(ctx->required_pass, info->name)) { - return true; - } - - return ctx->opt_level >= info->opt_level; -} - Pass GetPass(const String& pass_name) { using tvm::runtime::Registry; const runtime::PackedFunc* f = nullptr; @@ -387,7 +375,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c for (const Pass& pass : passes) { ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!pass_ctx.PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 82ac1c57018e..a0828d1cac6c 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -338,7 +338,24 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); + + // Do layout rewrite for auto-scheduler. + if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { + const auto& target = (*targets.begin()).second; + Pass major_pass = transform::AutoSchedulerLayoutRewrite(); + + if (target->kind->device_type == kDLCPU && pass_ctx.PassEnabled(major_pass->Info())) { + With tctx(target); + relay_module = major_pass(relay_module); + // Defuse ops to fold constants, then fuse them again + relay_module = transform::DefuseOps()(relay_module); + relay_module = transform::FoldConstant()(relay_module); + relay_module = transform::FuseOps()(relay_module); + } + } + relay_module = transform::InferType()(relay_module); + // Inline the functions that have been lifted by the module scope. // // TODO(@zhiics) Note that we need to be careful about the subgraphs with diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 1559d7edf35f..98d913662953 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -101,9 +101,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> explicit ScheduleGetter(Target target) : target_(target), device_copy_op_(Op::Get("device_copy")) { // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = transform::PassContext::Current() - ->GetConfig("relay.backend.use_auto_scheduler", Bool(false)) - .value(); + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); } CachedFunc Create(const Function& prim_func) { @@ -322,6 +320,17 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> const Op& device_copy_op_; }; +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. 
+ */ +CachedFunc CreateSchedule(const Function& source_func, const Target& target) { + return ScheduleGetter(target).Create(source_func); +} + // Creates shape function from functor. class MakeShapeFunc : public backend::MemoizedExprTranslator> { public: @@ -680,17 +689,6 @@ class CompileEngineImpl : public CompileEngineNode { */ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - /*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ - CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); - } - private: // implement lowered func CCacheValue LowerInternal(const CCacheKey& key) { diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 55822917b6b7..d7628e7a5bdf 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -241,6 +241,15 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc CreateSchedule(const Function& source_func, const Target& target); + /*! * \brief Check if the type is dynamic. * \param ty The type to be checked. diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4426642e8e18..36016cf3658a 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -295,6 +295,15 @@ inline std::string GetExtSymbol(const Function& func) { return std::string(name_node.value()); } +/*! + * \brief Return whether the auto scheduler is enabled in the pass context. + */ +inline bool IsAutoSchedulerEnabled() { + return transform::PassContext::Current() + ->GetConfig("relay.backend.use_auto_scheduler", Bool(false)) + .value(); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 34bff0f5b858..d2fb6aa2b9c3 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -52,6 +52,8 @@ Expr MakeFull(Expr fill_value, Array shape, DataType dtype); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); +Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout); + Expr MakeOnes(Array shape, DataType dtype); Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode); diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f0112227153d..76939fc9cca9 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -143,6 +143,13 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); + // If the layout is rewritten by auto-scheduler, + // we just forcly apply the layout provided by auto-scheduler and + // skip the normal inference logic. 
+ if (param->auto_scheduler_rewritten_layout.size() > 0) { + return false; + } + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); if (!trans_in_layout.defined()) { reporter->GetDiagCtx().Emit( diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d1f2f267c580..ee154100d78f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2738,7 +2738,55 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); -/* relay._contrib_reverse_reshape */ +// relay.auto_scheduler_layout_transform +TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs); + +Array AutoSchedulerLayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::auto_scheduler_layout_transform(inputs[0], param->src_layout, param->dst_layout)}; +} + +bool AutoSchedulerLayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const auto* data = types[0].as(); + CHECK(data != nullptr); + const AutoSchedulerLayoutTransformAttrs* params = attrs.as(); + + Array dst_shape; + std::vector dst_axes; + + topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes); + + reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); + return true; +} + +Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout) { + auto attrs = make_object(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static const Op& op = Op::Get("auto_scheduler_layout_transform"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform") + .set_body_typed(MakeAutoSchedulerLayoutTransform); + +RELAY_REGISTER_OP("auto_scheduler_layout_transform") + .describe(R"code(Transform the input kernel layout. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("auto_scheduler_layout_transform", AutoSchedulerLayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", AutoSchedulerLayoutTransformCompute); + +// relay._contrib_reverse_reshape Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc new file mode 100644 index 000000000000..2d95f9ee4b2c --- /dev/null +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler_layout_rewrite.cc + * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in + * conv2d and dense layers) according to the tile structure generated by the auto-scheduler. + */ + +#include "auto_scheduler_layout_rewrite.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "../backend/compile_engine.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +// Two global variables for receiving layout information from python +std::deque AutoSchedulerLayoutRewriter::global_ori_layouts_queue; +std::deque AutoSchedulerLayoutRewriter::global_new_layouts_queue; + +// Copy an Attrs but with a new auto_scheduler_rewritten_layout field. +template +Attrs CopyAttrsWithNewLayout(const T* ptr, const std::string& layout) { + auto n = make_object(*ptr); + n->auto_scheduler_rewritten_layout = layout; + return Attrs(n); +} + +// Mutate ops in a function +class FuncMutator : public ExprMutator { + public: + FuncMutator(const std::deque& ori_layouts_queue, + const std::deque& new_layouts_queue) + : ExprMutator(), + ori_layouts_queue_(ori_layouts_queue), + new_layouts_queue_(new_layouts_queue) {} + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as(); + if (call && call->op.as() && + (std::find(target_ops_.begin(), target_ops_.end(), n->op.as()->name) != + target_ops_.end()) && + !ori_layouts_queue_.empty() && !new_layouts_queue_.empty()) { + // Pop a new layout from the queue + const std::string ori_layout = ori_layouts_queue_.front(); + const std::string new_layout = new_layouts_queue_.front(); + ori_layouts_queue_.pop_front(); + new_layouts_queue_.pop_front(); + + // Insert a new op to do layout transform. (This will be simplified by FoldConstant later). + Expr updated_kernel = MakeAutoSchedulerLayoutTransform(call->args[1], ori_layout, new_layout); + Array updated_args = {call->args[0], updated_kernel}; + + // Update the attrs + Attrs updated_attrs; + if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } else if (auto pattr = call->attrs.as()) { + updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); + } + new_n = Call(call->op, updated_args, updated_attrs); + } + return new_n; + } + + private: + std::deque ori_layouts_queue_; + std::deque new_layouts_queue_; + + std::vector target_ops_{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d"}; +}; + +Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + if (const auto* call = new_n.as()) { + if (const auto* func = call->op.as()) { + global_ori_layouts_queue.clear(); + global_new_layouts_queue.clear(); + + // Use ScheduleGetter to call python lower functions. + // This is used to get the layout transform information. + // The layout transformation will be recorded to global_ori_layouts_queue + // and global_new_layouts_queue in ComputeDAG::RewriteLayout.
+ auto f = runtime::Registry::Get("auto_scheduler.enter_layout_rewrite"); + CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; + (*f)(); + + CreateSchedule(GetRef(func), Target::Current()); + + f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); + CHECK(f) << "Could not find auto_scheduler.exit_layout_rewrite function."; + (*f)(); + + // Mutate the called function + if (!global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { + auto ret = FuncMutator(global_ori_layouts_queue, global_new_layouts_queue).VisitExpr(new_n); + return ret; + } + } + } + + return new_n; +} + +Expr AutoSchedulerLayoutRewrite(const Expr& expr) { + return AutoSchedulerLayoutRewriter().Mutate(expr); +} + +namespace transform { + +Pass AutoSchedulerLayoutRewrite() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::AutoSchedulerLayoutRewrite(f)); + }; + return CreateFunctionPass(pass_func, 3, "AutoSchedulerLayoutRewrite", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite") + .set_body_typed(AutoSchedulerLayoutRewrite); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.h b/src/relay/transforms/auto_scheduler_layout_rewrite.h new file mode 100644 index 000000000000..d0d89db42e68 --- /dev/null +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler_layout_rewrite.h + * \brief Rewrite the layout of "layout free" tensors (e.g., the weight tensors in + * conv2d and dense layers) according to the tile structure generated by the auto-scheduler. + */ + +#ifndef TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ +#define TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ + +#include + +#include +#include + +namespace tvm { +namespace relay { + +class AutoSchedulerLayoutRewriter : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) final; + + // Two global variables for receiving layout information from python + static std::deque global_ori_layouts_queue; + static std::deque global_new_layouts_queue; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_AUTO_SCHEDULER_LAYOUT_REWRITE_H_ diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py new file mode 100644 index 000000000000..c92a580a7c9b --- /dev/null +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test layout rewrite support for whole neural networks""" +import tempfile + +import numpy as np + +import tvm +from tvm import relay, auto_scheduler +from tvm.contrib import graph_runtime +import tvm.testing + + +def get_np_array(var, dtype): + return np.random.randn(*[int(x) for x in var.type_annotation.shape]).astype(dtype) + + +def get_relay_conv2d( + outc=128, + inc=64, + height=14, + width=14, + kh=3, + kw=3, + batch=1, + pad=0, + stride=1, + dilation=1, + layout="NHWC", +): + dtype = "float32" + if layout == "NHWC": + kernel_layout = "HWIO" + d = relay.var("data", shape=(batch, height, width, inc), dtype=dtype) + w = relay.var("weight", shape=(kh, kw, inc, outc), dtype=dtype) + elif layout == "NCHW": + kernel_layout = "OIHW" + d = relay.var("data", shape=(batch, inc, height, width), dtype=dtype) + w = relay.var("weight", shape=(outc, inc, kh, kw), dtype=dtype) + + y = relay.nn.conv2d( + d, + w, + padding=pad, + kernel_size=(kh, kw), + strides=(stride, stride), + dilation=(dilation, dilation), + channels=outc, + groups=1, + data_layout=layout, + kernel_layout=kernel_layout, + ) + mod = tvm.IRModule() + mod["main"] = relay.Function([d, w], y) + data, weight = get_np_array(d, dtype), get_np_array(w, dtype) + return mod, data, weight + + +def tune_and_check(mod, data, weight): + # Extract tasks from a relay program + target = tvm.target.Target("llvm") + tasks, task_weights = auto_scheduler.extract_tasks(mod, target=target, params={}) + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + # log_file = "test_layout_rewrite.json" + + # Tune tasks + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=1, + num_measures_per_round=1, + builder=auto_scheduler.LocalBuilder(timeout=60), + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + ) + tuner.tune(tune_option, search_policy="sketch.random") + + # Compile and run + def compile_and_run(disabled_pass={}): + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + disabled_pass=disabled_pass, + ): + lib = relay.build(mod, target=target, params={"weight": weight}) + + ctx = tvm.cpu() + module = graph_runtime.GraphModule(lib["default"](ctx)) + module.set_input("data", data) + module.run() + + return module.get_output(0).asnumpy() + + # Check correctness + actual_output = compile_and_run() + expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"}) + + tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4) + + +def test_conv2d(): + mod, data, weight = get_relay_conv2d(kh=1, kw=1) + tune_and_check(mod, data, weight) + + +if __name__ == "__main__": + test_conv2d() From 082dcc9ac532d62ec5810a20b19e34051074e38f Mon Sep 17 00:00:00 2001 From: 
Lianmin Zheng Date: Sun, 29 Nov 2020 09:46:38 -0800 Subject: [PATCH 02/12] fix --- include/tvm/relay/attrs/nn.h | 4 ---- python/tvm/auto_scheduler/measure.py | 4 +++- python/tvm/relay/op/strategy/generic.py | 9 +++++++-- python/tvm/relay/op/strategy/x86.py | 2 +- src/relay/op/nn/convolution.h | 19 ++++++++++--------- .../auto_scheduler_layout_rewrite.cc | 5 +---- 6 files changed, 22 insertions(+), 21 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 278e98097e84..67b719ef1dcd 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -216,7 +216,6 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { std::string data_layout; std::string kernel_layout; std::string out_layout; - std::string auto_scheduler_rewritten_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { @@ -269,9 +268,6 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); - TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) - .set_default("") - .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 117cd4f8bc71..b9d7148be784 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -544,7 +544,9 @@ def _timed_func(inp_serialized, build_func, verbose): args = [] try: - sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True) + sch, args = task.compute_dag.apply_steps_from_state( + inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed + ) # pylint: disable=broad-except except Exception: error_no = MeasureErrorNo.INSTANTIATION_ERROR diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e8c3f41c7155..f746926880cf 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -168,7 +168,11 @@ def schedule_bitpack(attrs, outs, target): # conv2d def wrap_compute_conv2d( - topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False + topi_compute, + need_data_layout=False, + need_out_layout=False, + has_groups=False, + need_auto_scheduler_layout=False, ): """Wrap conv2d topi compute""" @@ -189,7 +193,8 @@ def _compute_conv2d(attrs, inputs, out_type): if need_out_layout: args.append(out_layout) args.append(out_dtype) - args.append(auto_scheduler_rewritten_layout) + if need_auto_scheduler_layout: + args.append(auto_scheduler_rewritten_layout) return [topi_compute(*args)] return _compute_conv2d diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a705cc6963e8..f85ee70b485c 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -118,7 +118,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWIO" strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc), + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), name="conv2d_nhwc.x86", ) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 76939fc9cca9..bd0d3d34e1a8 100644 --- 
a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -143,13 +143,6 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - // If the layout is rewritten by auto-scheduler, - // we just forcly apply the layout provided by auto-scheduler and - // skip the normal inference logic. - if (param->auto_scheduler_rewritten_layout.size() > 0) { - return false; - } - const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); if (!trans_in_layout.defined()) { reporter->GetDiagCtx().Emit( @@ -219,8 +212,16 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight != nullptr) { weight_dtype = weight->dtype; } - // assign result to reporter - reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + + if (param->auto_scheduler_rewritten_layout.size() == 0) { + // Normal case: assign result to reporter + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); + } else { + // If the layout is rewritten by auto-scheduler, + // we just forcly apply the layout provided by auto-scheduler and + // skip the normal inference logic. + ; // do nothing + } } else { // use weight to infer the conv shape. if (weight == nullptr) return false; diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 2d95f9ee4b2c..259ae8fc2a8e 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -83,8 +83,6 @@ class FuncMutator : public ExprMutator { Attrs updated_attrs; if (auto pattr = call->attrs.as()) { updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); - } else if (auto pattr = call->attrs.as()) { - updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout); } new_n = Call(call->op, updated_args, updated_attrs); } @@ -95,8 +93,7 @@ class FuncMutator : public ExprMutator { std::deque ori_layouts_queue_; std::deque new_layouts_queue_; - std::vector target_ops_{"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d"}; + std::vector target_ops_{"nn.conv2d"}; }; Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { From 3d492b397184e8104d9f27bb05aa070f616ad6c0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Nov 2020 10:17:44 -0800 Subject: [PATCH 03/12] fix lint --- src/relay/op/nn/convolution.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index bd0d3d34e1a8..13e87a54b9d8 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -220,7 +220,7 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, // If the layout is rewritten by auto-scheduler, // we just forcly apply the layout provided by auto-scheduler and // skip the normal inference logic. - ; // do nothing + {} // do nothing } } else { // use weight to infer the conv shape. 
From d3c69ca3af5fb5083baf48ca072812bebc81225e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 1 Dec 2020 19:54:40 -0800 Subject: [PATCH 04/12] fix attrs --- include/tvm/relay/attrs/nn.h | 3 --- python/tvm/relay/op/strategy/generic.py | 8 ++++++-- src/relay/transforms/auto_scheduler_layout_rewrite.cc | 8 ++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 67b719ef1dcd..f8aa1fc508b6 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -171,9 +171,6 @@ struct Conv2DAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); - TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) - .set_default("") - .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f746926880cf..e2ab3cd4f5d4 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -19,7 +19,7 @@ import logging import re -from tvm import topi +from tvm import topi, _ffi from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple from tvm.target import generic_func, override_native_generic_func from .. import op as _op @@ -166,6 +166,10 @@ def schedule_bitpack(attrs, outs, target): return topi.generic.schedule_bitpack(outs) +get_auto_scheduler_rewritten_layout = _ffi.get_global_func( + "relay.attrs.get_auto_scheduler_rewritten_layout" +) + # conv2d def wrap_compute_conv2d( topi_compute, @@ -183,7 +187,7 @@ def _compute_conv2d(attrs, inputs, out_type): data_layout = attrs.get_str("data_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype - auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout") + auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs) out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilation] if has_groups: diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 259ae8fc2a8e..c9875ef5d718 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -146,6 +146,14 @@ Pass AutoSchedulerLayoutRewrite() { TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite") .set_body_typed(AutoSchedulerLayoutRewrite); +TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout") + .set_body_typed([](const Attrs& attrs) { + if (attrs->IsInstance()) { + return attrs.as()->auto_scheduler_rewritten_layout; + } + return std::string(); + }); + } // namespace transform } // namespace relay From 2e8f144084e09663a56de8ff63891d9e44c79b9e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 1 Dec 2020 23:07:47 -0800 Subject: [PATCH 05/12] trigger CI --- include/tvm/ir/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 2c5b6293e6c9..56905ded5201 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -198,7 +198,7 @@ class PassContext : public ObjectRef { TVM_DLL void Trace(const IRModule& module, const PassInfo& info, 
bool is_before) const; /*! - * \brief Check if a pass is enabled. + * \brief Check whether a pass is enabled. * \param info The pass information. * \return true if the pass is enabled. Otherwise, false. */ From b37bd3010a1c17bb1a4ea5656fa390cb68f6a605 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 09:03:14 -0800 Subject: [PATCH 06/12] Apply suggestions from code review --- python/tvm/topi/nn/conv2d.py | 1 + tests/python/relay/test_auto_scheduler_layout_rewrite.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 492b62b3e21d..8d591a20839a 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -381,6 +381,7 @@ def conv2d_nhwc( if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout + # todo(merrymercy): wrap this with a more general interface. if len(Filter.shape) >= 10: # For cpu tile structure SSRSRS base = len(Filter.shape) - 10 diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py index c92a580a7c9b..299fcb8ebb2c 100644 --- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py @@ -77,7 +77,6 @@ def tune_and_check(mod, data, weight): with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - # log_file = "test_layout_rewrite.json" # Tune tasks tuner = auto_scheduler.TaskScheduler(tasks, task_weights) From 70596bd2c11274976cb38c9408f22a1b0fd5e9c0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 09:47:24 -0800 Subject: [PATCH 07/12] trigger CI --- include/tvm/relay/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a2ffe5e853b3..93e64e68571e 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -108,7 +108,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); /*! * \brief The inverse operation of FuseOps. It transforms a fused program returned by - * FuseOps into the program before FuseOps. (i.e., x == DefuseOps(FuseOps(x))) + * FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x))) * * \return The pass. 
*/ From d9bbff2ddf9cd57ba20a8c456856de045594d125 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 11:53:22 -0800 Subject: [PATCH 08/12] Update python/tvm/auto_scheduler/relay_integration.py --- python/tvm/auto_scheduler/relay_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 9cd7b43d8067..a3e97897146f 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -330,7 +330,7 @@ def rewrite_compute_body(compute_tensor, new_layout): # Get layout free placeholders layout_free_placeholders = op.attrs["layout_free_placeholders"] - assert len(layout_free_placeholders) == 1 + assert len(layout_free_placeholders) == 1, "Only support one layout free placeholder" placeholder_op = layout_free_placeholders[0].op # Rewrite the index expression in body From f9833cfe303fb706d167e5e490eac3d7bb065ad2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 11:53:29 -0800 Subject: [PATCH 09/12] Update python/tvm/auto_scheduler/relay_integration.py --- python/tvm/auto_scheduler/relay_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index a3e97897146f..25b88811709e 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -276,7 +276,7 @@ def auto_schedule_topi(outs, has_complex_op): dispatch_ctx = DispatchContext.current state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: - return te.create_schedule([x.op for x in outs]) + return None # rewrite the layout and update the context for the new dag dag = ComputeDAG(outs) From 0fd4086a46253e860a89420ff11e168461c6aa28 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 11:53:34 -0800 Subject: [PATCH 10/12] Update python/tvm/auto_scheduler/compute_dag.py --- python/tvm/auto_scheduler/compute_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index c1a195f3c8fe..cba3600ccf6e 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -164,7 +164,7 @@ def infer_bound_from_state(self, state): def rewrite_layout_from_state(self, state): """ - Rewrite the layout according to the transform steps in the history of a state + Rewrite the layout of the DAG according to the history transform steps of a state. Parameters ---------- From fd647807c1697594e9e55e110e9523984a296c77 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 13:42:40 -0800 Subject: [PATCH 11/12] Trigger CI --- include/tvm/topi/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index d200b1679bab..c2a4843dedd0 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1400,7 +1400,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } -/*! \brief utility function for auto_scheduler_layout_transform */ +/*! 
\brief Utility function for auto_scheduler_layout_transform */ inline void parse_auto_scheduler_layout(const String& layout, Array* shape, std::vector* axes) { int32_t factor = 0; From f87397ddbb696312c9f7fbc577928820ba42a154 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 2 Dec 2020 14:26:00 -0800 Subject: [PATCH 12/12] Apply suggestions from code review --- include/tvm/relay/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 93e64e68571e..e4b39da85206 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -325,7 +325,7 @@ TVM_DLL Pass AlterOpLayout(); /*! * \brief Do layout rewrite according to the tile structure created by auto-scheduler. - * \return The pass. + * \return The pass */ TVM_DLL Pass AutoSchedulerLayoutRewrite();
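
Usage sketch (reviewer illustration, not part of the patch series): the snippet below mirrors the new tests/python/relay/test_auto_scheduler_layout_rewrite.py and shows how the AutoSchedulerLayoutRewrite pass introduced here is exercised during compilation. Only the config key "relay.backend.use_auto_scheduler" and the pass name "AutoSchedulerLayoutRewrite" come from these patches; the Relay module, parameters, and tuning-log path are assumed placeholders supplied by the caller.

import tvm
from tvm import relay, auto_scheduler

def build_with_layout_rewrite(mod, params, log_file, disabled_pass=None):
    # Replay the best schedules found during tuning; the layout rewrite pass
    # only runs when the auto-scheduler backend is enabled in the config.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_auto_scheduler": True},
            disabled_pass=disabled_pass or {},
        ):
            return relay.build(mod, target="llvm", params=params)

# As in the new test, one can build once with the rewrite and once with it
# disabled, then compare the two outputs for correctness:
#   lib = build_with_layout_rewrite(mod, {"weight": weight}, log_file)
#   ref = build_with_layout_rewrite(mod, {"weight": weight}, log_file,
#                                   disabled_pass={"AutoSchedulerLayoutRewrite"})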