[AutoScheduler] Support layout rewrite for whole networks #6987

Merged: 12 commits, Dec 3, 2020
7 changes: 7 additions & 0 deletions include/tvm/ir/transform.h
@@ -197,6 +197,13 @@ class PassContext : public ObjectRef {
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

/*!
* \brief Check whether a pass is enabled.
* \param info The pass information.
* \return true if the pass is enabled. Otherwise, false.
*/
TVM_DLL bool PassEnabled(const PassInfo& info) const;

/*!
* \brief Register a valid configuration option and its ValueType for validation.
*
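The PassEnabled helper added above is what lets a pass consult the PassContext's opt_level, required_pass, and disabled_pass settings. A minimal Python-side sketch of how this PR relies on it (mirroring the usage in relay_integration.py below; `mod` and `params` are assumed to be defined elsewhere):

import tvm
from tvm import relay, transform

# Passes listed in `disabled_pass` are reported as disabled by PassEnabled,
# so AutoSchedulerLayoutRewrite is skipped during task extraction.
with transform.PassContext(
    opt_level=3,
    config={"relay.backend.use_auto_scheduler": True},
    disabled_pass={"AutoSchedulerLayoutRewrite"},
):
    opt_mod, _ = relay.optimize(mod, target="llvm", params=params)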
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/nn.h
@@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -350,6 +350,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
}
};

/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
struct AutoSchedulerLayoutTransformAttrs
: public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;

TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
"relay.attrs.AutoSchedulerLayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
}
};

/*! \brief Attributes for ShapeOf operator */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
DataType dtype;
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
@@ -106,6 +106,14 @@ TVM_DLL Pass FoldConstant();
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
* \brief The inverse operation of FuseOps. It transforms a fused program returned by
* FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x)))
*
* \return The pass.
*/
TVM_DLL Pass DefuseOps();

/*!
* \brief Rewrite the annotated program.
*
@@ -315,6 +323,12 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Do layout rewrite according to the tile structure created by auto-scheduler.
* \return The pass
*/
TVM_DLL Pass AutoSchedulerLayoutRewrite();

/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
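A minimal Python-side sketch of the round-trip property stated in the DefuseOps comment above (an illustration, not part of this PR; assumes `mod` is a Relay IRModule and that InferType is run first, as FuseOps expects a type-checked module):

import tvm
from tvm import relay

seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.FuseOps(fuse_opt_level=2),
        relay.transform.DefuseOps(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    defused = seq(mod)
# The defused module should be structurally equal to the (type-inferred) input.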
68 changes: 68 additions & 0 deletions include/tvm/topi/transform.h
@@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
name, tag);
}

/*! \brief Utility function for auto_scheduler_layout_transform */
inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
std::vector<std::string>* axes) {
int32_t factor = 0;
std::string axis = "";
for (char c : std::string(layout)) {
if (c >= 'A' && c <= 'z') {
axis += c;
if (factor != 0) {
shape->push_back(factor);
factor = 0;
}
} else if (c >= '0' && c <= '9') {
factor = factor * 10 + c - '0';
if (!axis.empty()) {
axes->push_back(axis);
axis = "";
}
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
if (!axis.empty()) {
axes->push_back(axis);
}
}

/*!
* \brief Transform the auto-scheduler generated layout according to
* \p src_layout and \p dst_layout
* \param src the source input.
* \param src_layout the source layout.
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
const String& dst_layout,
const String name = "T_auto_scheduler_layout_trans",
const String tag = kInjective) {
Array<PrimExpr> src_shape;
std::vector<std::string> src_axes;
Array<PrimExpr> dst_shape;
std::vector<std::string> dst_axes;

parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
return compute(
dst_shape,
[&](const Array<Var>& dst_indices) {
Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<PrimExpr> src_indices;
for (const std::string& src_axis : src_axes) {
PrimExpr src_index = 0;
CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
for (size_t i = 0; i < dst_axes.size(); ++i) {
if (dst_axes[i] == src_axis) {
src_index = src_index * dst_shape[i] + dst_indices_expr[i];
}
}
src_indices.push_back(src_index);
}
return src(src_indices);
},
name, tag);
}

/*!
* \brief Get the shape of input tensor.
* \param src the input tensor.
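To make the layout-string format parsed above concrete, here is an illustrative Python re-implementation of the same rule (an explanatory sketch, not a TVM API): runs of digits give an axis length and runs of letters give an axis name.

def parse_auto_scheduler_layout(layout):
    """Parse e.g. "1N2C112H112W16c" into ([1, 2, 112, 112, 16], ["N", "C", "H", "W", "c"])."""
    shape, axes = [], []
    factor, axis = 0, ""
    for c in layout:
        if c.isalpha():
            # A letter closes the pending length and extends the axis name.
            axis += c
            if factor != 0:
                shape.append(factor)
                factor = 0
        elif c.isdigit():
            # A digit closes the pending axis name and extends the length.
            factor = factor * 10 + int(c)
            if axis:
                axes.append(axis)
                axis = ""
        else:
            raise ValueError("Invalid layout " + layout)
    if axis:
        axes.append(axis)
    return shape, axes

print(parse_auto_scheduler_layout("1N2C112H112W16c"))
# -> ([1, 2, 112, 112, 16], ['N', 'C', 'H', 'W', 'c'])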
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
@@ -44,7 +44,7 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .relay_integration import extract_tasks
from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
17 changes: 17 additions & 0 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -162,6 +162,23 @@ def infer_bound_from_state(self, state):
updated_state.stage_id_map[k] = v
return updated_state

def rewrite_layout_from_state(self, state):
"""
Rewrite the layout of the DAG according to the history transform steps of a state.

Parameters
----------
state : Union[State, StateObject]
The state from which we get transform steps.

Returns
-------
updated_dag : ComputeDAG
The compute dag with rewritten layout.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)

def hash_key(self):
"""Return the hash key of this compute DAG.

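A hedged usage sketch of rewrite_layout_from_state added above (assumes `task` is a tuned SearchTask whose records are stored in "records.json"; the file and variable names are illustrative):

from tvm import auto_scheduler

inp, _ = auto_scheduler.load_best("records.json", task.workload_key)
new_dag = task.compute_dag.rewrite_layout_from_state(inp.state)
print(new_dag)  # placeholders now carry the auto-scheduler rewritten layout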
4 changes: 3 additions & 1 deletion python/tvm/auto_scheduler/measure.py
@@ -544,7 +544,9 @@ def _timed_func(inp_serialized, build_func, verbose):
args = []

try:
sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True)
sch, args = task.compute_dag.apply_steps_from_state(
inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
)
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.INSTANTIATION_ERROR
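Outside the measurement path, the same RewriteForPreTransformed option can be used to build a tuned schedule whose layout-free placeholders already have the rewritten shape (a minimal sketch; `task`, `state`, and `target` are assumed to exist):

import tvm
from tvm.auto_scheduler import ComputeDAG

sch, args = task.compute_dag.apply_steps_from_state(
    state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
)
func = tvm.build(sch, args, target)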
103 changes: 96 additions & 7 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -23,11 +23,15 @@
"""

import logging
import json
import threading

import tvm
from tvm import autotvm, te, transform
from tvm.te.tensor import ComputeOp, PlaceholderOp
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
from .compute_dag import ComputeDAG
from .dispatcher import DispatchContext
from .search_task import SearchTask
@@ -46,7 +50,11 @@ def call_all_topi_funcs(mod, params, target):
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True

with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
with transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
disabled_pass={"AutoSchedulerLayoutRewrite"},
):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
@@ -158,6 +166,20 @@ def add_workload_key(self, workload_key, ccache_key):
self.wkl_key_to_ccache_key[workload_key] = ccache_key


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
"""Enter layout rewrite tracing environment"""
env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
env.__enter__()


@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
def exit_layout_rewrite():
"""Exit layout rewrite tracing environment"""
env = TracingEnvironment.current
env.__exit__(None, None, None)


def traverse_to_get_io_tensors(outs):
"""Traverse from a list of output tensors to get both input and output tensors

@@ -230,11 +252,13 @@ def auto_schedule_topi(outs, has_complex_op):
key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu backend
enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys
target = tvm.target.Target.current()
enable_layout_rewrite = "cpu" in target.keys

env = TracingEnvironment.current
if env is None: # in the final build mode
state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag)
if env is None:
# in the final build mode
state = DispatchContext.current.query(target, key, has_complex_op, dag)
if state is None:
return None

@@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# todo(merrymercy, minminsun): port layout rewrite
raise NotImplementedError
# in prepare_layout_rewrite mode
if enable_layout_rewrite and has_layout_free:
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
return None

# rewrite the layout and update the context for the new dag
dag = ComputeDAG(outs)
new_dag = dag.rewrite_layout_from_state(state)
new_key = json.dumps((new_dag.hash_key(),))
if new_key != key:
dispatch_ctx.update(target, new_key, state)
return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)

return schedule


def tensor_no_check_call(self, *indices):
"""An indexing function without any check.
This is the same as `tvm.te.Tensor::__call__` except that the safety
check is removed.
"""
indices = convert_to_object(indices)
args = []
for x in indices:
if isinstance(x, _expr.PrimExpr):
args.append(x)
elif isinstance(x, _expr.IterVar):
args.append(x.var)
else:
raise ValueError("The indices must be expression")

return _expr.ProducerLoad(self, args)


def remove_index_check(tensor):
"""Remove the safety check in the indexing function for a tensor.
This is done by monkey patching its indexing function.
After removing the check, we are allowed to create a
temporary wrong IR and fix it later in other places.

Parameters
----------
tensor: Tensor
The tensor to remove index check.
"""
# Monkey patch the indexing function
tensor.__call__ = tensor_no_check_call.__get__(tensor, Tensor)


def rewrite_compute_body(compute_tensor, new_layout):
"""Rewrite the body of a ComputeOp according to a new layout of a placeholder"""
op = compute_tensor.op

# Get layout free placeholders
layout_free_placeholders = op.attrs["layout_free_placeholders"]
assert len(layout_free_placeholders) == 1, "Only support one layout free placeholder"
placeholder_op = layout_free_placeholders[0].op

# Rewrite the index expression in body
body = []
for b in op.body:
body.append(_ffi_api.RewriteIndexForNewLayout(placeholder_op, new_layout, b))
op_node = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, body)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
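Taken together, these hooks are driven from user code roughly as follows (a hedged end-to-end sketch; `mod` and `params` come from a Relay frontend, and the tuning step that produces "records.json" is elided). During the final build, the AutoSchedulerLayoutRewrite pass queries the tuned states through the tracing hooks above (PREPARE_LAYOUT_REWRITE mode) and inserts auto_scheduler_layout_transform ops on the affected weights.

import tvm
from tvm import auto_scheduler, relay

target = tvm.target.Target("llvm")
tasks = auto_scheduler.extract_tasks(mod["main"], params, target)
# ... tune `tasks` with auto_scheduler and log the results to "records.json" ...
# (newer TVM versions also return per-task weights from extract_tasks)

with auto_scheduler.ApplyHistoryBest("records.json"):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target=target, params=params)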
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -79,6 +79,8 @@ def compute_strided_set(attrs, inputs, output_type):
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
15 changes: 13 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
@@ -19,7 +19,7 @@
import logging

import re
from tvm import topi
from tvm import topi, _ffi
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
@@ -166,9 +166,17 @@ def schedule_bitpack(attrs, outs, target):
return topi.generic.schedule_bitpack(outs)


get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
"relay.attrs.get_auto_scheduler_rewritten_layout"
)

# conv2d
def wrap_compute_conv2d(
topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False
topi_compute,
need_data_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
):
"""Wrap conv2d topi compute"""

@@ -179,6 +187,7 @@ def _compute_conv2d(attrs, inputs, out_type):
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
@@ -188,6 +197,8 @@ def _compute_conv2d(attrs, inputs, out_type):
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(auto_scheduler_rewritten_layout)
return [topi_compute(*args)]

return _compute_conv2d
3 changes: 1 addition & 2 deletions python/tvm/relay/op/strategy/x86.py
@@ -117,9 +117,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)