[AutoScheduler] Add layout rewrite pass in relay
merrymercy committed Nov 29, 2020
1 parent 9a9ec1a commit dad8cc7
Showing 25 changed files with 721 additions and 65 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/transform.h
@@ -197,6 +197,13 @@ class PassContext : public ObjectRef {
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

/*!
* \brief Check if a pass is enabled.
* \param info The pass information.
* \return true if the pass is enabled. Otherwise, false.
*/
TVM_DLL bool PassEnabled(const PassInfo& info) const;

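PassEnabled consults the required_pass and disabled_pass sets (plus opt_level) of the active PassContext. A minimal Python-side sketch of the knob it honors, using the same pass name this patch later disables during task extraction; this is an illustration, not part of the diff:

import tvm

# Sketch only: passes listed in disabled_pass are reported as not enabled by
# PassEnabled(), so the pass infrastructure skips them.
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AutoSchedulerLayoutRewrite"]):
    ...  # relay.optimize() / relay.build() here will skip the disabled pass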
/*!
* \brief Register a valid configuration option and its ValueType for validation.
*
27 changes: 27 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
@@ -170,6 +171,9 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
"Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
@@ -212,6 +216,7 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
@@ -264,6 +269,9 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
"Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
@@ -302,6 +310,7 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
@@ -353,6 +362,9 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
"Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
@@ -923,10 +935,14 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
@@ -935,6 +951,17 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
std::string auto_scheduler_rewritten_layout;

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");
}
};

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -350,6 +350,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
}
};

/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
struct AutoSchedulerLayoutTransformAttrs
: public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;

TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
"relay.attrs.AutoSchedulerLayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
}
};

/*! \brief Attributes for ShapeOf operator */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
DataType dtype;
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
@@ -106,6 +106,14 @@ TVM_DLL Pass FoldConstant();
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
* \brief The inverse operation of FuseOps. It transforms a fused program returned by
* FuseOps into the program before FuseOps. (i.e., x == DefuseOps(FuseOps(x)))
*
* \return The pass.
*/
TVM_DLL Pass DefuseOps();

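A hedged round-trip sketch of the property documented above; it assumes DefuseOps is also exposed to Python as relay.transform.DefuseOps (the binding name is an assumption, not shown in this diff):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16, 56, 56), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x + x)))

fused = relay.transform.FuseOps(fuse_opt_level=2)(mod)
defused = relay.transform.DefuseOps()(fused)  # assumed Python binding
# `defused` should be equivalent to `mod` before fusion: x == DefuseOps(FuseOps(x))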
/*!
* \brief Rewrite the annotated program.
*
@@ -315,6 +323,12 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Do layout rewrite according to the tile structure created by auto-scheduler.
* \return The pass.
*/
TVM_DLL Pass AutoSchedulerLayoutRewrite();

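The pass only has an effect when auto-scheduler tuning records are applied and the auto-scheduler backend flag is set; the relay integration changes later in this patch disable it during task extraction. A hedged end-to-end sketch, where the network, target, and log file are placeholders:

import tvm
from tvm import relay, auto_scheduler

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
w = relay.var("w", shape=(16, 3, 3, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.conv2d(x, w, padding=(1, 1))))

log_file = "tuning.json"  # assumed to hold auto-scheduler records for this workload
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        # AutoSchedulerLayoutRewrite runs here and rewrites kernel layouts
        # according to the tuned tile structure.
        lib = relay.build(mod, target="llvm", params={})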
/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
68 changes: 68 additions & 0 deletions include/tvm/topi/transform.h
@@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
name, tag);
}

/*! \brief utility function for auto_scheduler_layout_transform */
inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
std::vector<std::string>* axes) {
int32_t factor = 0;
std::string axis = "";
for (char c : std::string(layout)) {
if (c >= 'A' && c <= 'z') {
axis += c;
if (factor != 0) {
shape->push_back(factor);
factor = 0;
}
} else if (c >= '0' && c <= '9') {
factor = factor * 10 + c - '0';
if (!axis.empty()) {
axes->push_back(axis);
axis = "";
}
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
if (!axis.empty()) {
axes->push_back(axis);
}
}

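The layout string encodes each axis as its extent followed by its name, so "1N32C112H112W" means shape (1, 32, 112, 112) over axes N, C, H, W. An illustrative Python re-implementation of the parser above (the C++ version is authoritative):

# Illustrative sketch of parse_auto_scheduler_layout, not TVM API.
# "1N32C112H112W" -> shape [1, 32, 112, 112], axes ['N', 'C', 'H', 'W']
def parse_auto_scheduler_layout(layout):
    shape, axes = [], []
    factor, axis = 0, ""
    for c in layout:
        if c.isalpha():
            axis += c
            if factor != 0:
                shape.append(factor)
                factor = 0
        elif c.isdigit():
            factor = factor * 10 + int(c)
            if axis:
                axes.append(axis)
                axis = ""
        else:
            raise ValueError("Invalid layout " + layout)
    if axis:
        axes.append(axis)
    return shape, axes

print(parse_auto_scheduler_layout("1N32C112H112W"))
# -> ([1, 32, 112, 112], ['N', 'C', 'H', 'W'])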
/*!
* \brief Transform a tensor from the auto-scheduler layout \p src_layout to \p dst_layout.
* \param src the source input.
* \param src_layout the source layout.
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
const String& dst_layout,
const String name = "T_auto_scheduler_layout_trans",
const String tag = kInjective) {
Array<PrimExpr> src_shape;
std::vector<std::string> src_axes;
Array<PrimExpr> dst_shape;
std::vector<std::string> dst_axes;

parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
return compute(
dst_shape,
[&](const Array<Var>& dst_indices) {
Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<PrimExpr> src_indices;
for (const std::string& src_axis : src_axes) {
PrimExpr src_index = 0;
CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
for (size_t i = 0; i < dst_axes.size(); ++i) {
if (dst_axes[i] == src_axis) {
src_index = src_index * dst_shape[i] + dst_indices_expr[i];
}
}
src_indices.push_back(src_index);
}
return src(src_indices);
},
name, tag);
}

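For each source axis, every destination axis carrying the same name contributes to the reconstructed index, outer to inner, via src_index = src_index * dst_extent + dst_index; this is how an axis that the rewritten layout splits into several sub-axes maps back to the original one. A hedged NumPy illustration of that index math (the layout strings here are made up for the example, not taken from a tuning log):

import numpy as np

# src layout "1N32C"   -> shape (1, 32),    axes [N, C]
# dst layout "2C1N16C" -> shape (2, 1, 16), axes [C, N, C]
# Both dst axes named "C" combine outer to inner: C_src = c_outer * 16 + c_inner.
src = np.arange(1 * 32).reshape(1, 32)
dst = np.empty((2, 1, 16), dtype=src.dtype)
for c_outer in range(2):
    for n in range(1):
        for c_inner in range(16):
            dst[c_outer, n, c_inner] = src[n, c_outer * 16 + c_inner]
# dst now holds the same values laid out in the rewritten (split) layout.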
/*!
* \brief Get the shape of input tensor.
* \param src the input tensor.
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
@@ -32,7 +32,7 @@

# Shortcut
from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
from .compute_dag import ComputeDAG
from .compute_dag import ComputeDAG, rewrite_compute_body
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
30 changes: 30 additions & 0 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -162,6 +162,22 @@ def infer_bound_from_state(self, state):
updated_state.stage_id_map[k] = v
return updated_state

def rewrite_layout_from_state(self, state):
"""
Rewrite the layout according to the transform steps in the history of a state.

Parameters
----------
state : Union[State, StateObject]
    The state from which we get transform steps.

Returns
-------
updated_dag : ComputeDAG
    The compute DAG with the rewritten layout.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)

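Outside of the relay integration shown later in this patch, a standalone call looks roughly like this; `task` and `state` are assumed to come from an auto-scheduler tuning session (for example create_task plus a loaded log):

# Hedged sketch: rewrite a task's DAG with the transform steps of its best state.
new_dag = task.compute_dag.rewrite_layout_from_state(state)  # task/state are assumptions
print(new_dag)  # layout-free placeholders now carry the rewritten layouts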
def hash_key(self):
"""Return the hash key of this compute DAG.
@@ -210,3 +226,17 @@ def __setstate__(self, state):
self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return
self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)


def rewrite_compute_body(compute, placeholder, new_layout):
"""Rewrite the body of a ComputeOp according to a new layout of a placeholder"""
body = []
for b in compute.op.body:
body.append(_ffi_api.RewriteIndexForNewLayout(placeholder.op, new_layout, b))
op_node = tvm.te._ffi_api.ComputeOp(
compute.op.name, compute.op.tag, compute.op.attrs, compute.op.axis, body
)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
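This helper is meant to be called from a TOPI compute when the new auto_scheduler_rewritten_layout attribute (added to the conv/dense attrs above) is non-empty, so the compute body indexes the placeholder through its rewritten layout. A hedged sketch with illustrative names:

from tvm import auto_scheduler

# Sketch only: `out` is the compute tensor, `kernel` the placeholder whose layout
# was rewritten, and `rewritten_layout` is attrs.auto_scheduler_rewritten_layout.
if rewritten_layout:
    out = auto_scheduler.rewrite_compute_body(out, kernel, rewritten_layout)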
44 changes: 38 additions & 6 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -23,6 +23,7 @@
"""

import logging
import json
import threading

import tvm
@@ -46,7 +47,11 @@ def call_all_topi_funcs(mod, params, target):
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True

with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
with transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
disabled_pass={"AutoSchedulerLayoutRewrite"},
):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
@@ -158,6 +163,18 @@ def add_workload_key(self, workload_key, ccache_key):
self.wkl_key_to_ccache_key[workload_key] = ccache_key


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def _enter_layout_rewrite(*args):
env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE)
env.__enter__()


@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
def _exit_layout_rewrite(*args):
env = TracingEnvironment.current
env.__exit__(None, None, None)


def traverse_to_get_io_tensors(outs):
"""Traverse from a list of output tensors to get both input and output tensors
@@ -230,11 +247,13 @@ def auto_schedule_topi(outs, has_complex_op):
key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu backend
enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys
target = tvm.target.Target.current()
enable_layout_rewrite = "cpu" in target.keys

env = TracingEnvironment.current
if env is None: # in the final build mode
state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag)
if env is None:
# in the final build mode
state = DispatchContext.current.query(target, key, has_complex_op, dag)
if state is None:
return None

@@ -247,8 +266,21 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# todo(merrymercy, minminsun): port layout rewrite
raise NotImplementedError
# in prepare_layout_rewrite mode
if enable_layout_rewrite and has_layout_free:
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
return te.create_schedule([x.op for x in outs])

# rewrite the layout and update the context
# for the new dag
dag = ComputeDAG(outs)
new_dag = dag.rewrite_layout_from_state(state)
new_key = json.dumps((new_dag.hash_key(),))
if new_key != key:
dispatch_ctx.update(target, new_key, state)
return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)

2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -79,6 +79,8 @@ def compute_strided_set(attrs, inputs, output_type):
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -179,6 +179,7 @@ def _compute_conv2d(attrs, inputs, out_type):
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout")
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
@@ -188,6 +189,7 @@ def _compute_conv2d(attrs, inputs, out_type):
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
args.append(auto_scheduler_rewritten_layout)
return [topi_compute(*args)]

return _compute_conv2d
1 change: 0 additions & 1 deletion python/tvm/relay/op/strategy/x86.py
@@ -117,7 +117,6 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),