Skip to content

Commit

Permalink
[TIR] add support for multi-blocking layout and their transformation (#9996)
Browse files Browse the repository at this point in the history

* add ceildiv and shapediv

* add boundary checking in layout_transform

* support multi-blocking and shape padding

* refine the log for shape transform

* add test for multi-blocking layout transform

* delete unwanted comments

* remove workaround

* fix lint errors
  • Loading branch information
yangulei authored Feb 21, 2022
1 parent 73cf51b commit 8d76075
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 42 deletions.
14 changes: 10 additions & 4 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,13 @@ class BijectiveLayoutNode : public Object {
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
*/
Array<PrimExpr> forward_rule;
Array<PrimExpr> index_forward_rule;
/*! \brief Describes how destination axes can be mapped to the source axes */
Array<PrimExpr> backward_rule;
Array<PrimExpr> index_backward_rule;
/*! \brief Describes how source shapes can be mapped to the destination shapes */
Array<PrimExpr> shape_forward_rule;
/*! \brief Describes how destination shapes can be mapped to the source shapes */
Array<PrimExpr> shape_backward_rule;

/*! \brief The source layout */
Layout src_layout;
Expand All @@ -307,8 +311,10 @@ class BijectiveLayoutNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
v->Visit("backward_rule", &backward_rule);
v->Visit("index_forward_rule", &index_forward_rule);
v->Visit("index_backward_rule", &index_backward_rule);
v->Visit("shape_forward_rule", &shape_forward_rule);
v->Visit("shape_backward_rule", &shape_backward_rule);
}

static constexpr const char* _type_key = "tir.BijectiveLayout";
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,22 @@ TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute ceil(a / b) where a and b are non-negative.
*
* Use this function for shape split calculation.
*
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \param span The location of this operation in the source.
* \return The result expression.
* \note this function does eager constant folding for
* shape types(int32, int64) when possible.
*/
TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute the remainder floor(a / b) where a and b are non-negative.
*
Expand All @@ -521,6 +537,17 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute ceil(a / b)
*
* \param a left operand
* \param b right operand
* \param span The location of this operation in the source.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute the remainder of floordiv
*
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,11 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
[&](const Array<Var>& dst_indices) {
Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true
for (size_t i = 0; i < src.ndim(); ++i) {
in_range = in_range && (src_indices[i] < src->shape[i]);
}
return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
},
name, tag);
}
Expand Down
19 changes: 3 additions & 16 deletions python/tvm/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
n_elems = 4

# convert kernel data layout from 4D to 7D
data_expr, kernel_expr = inputs
kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn))
kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
kernel_OHWoIi = relay.reshape(
kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn)
)
kernel_OHWoIie = relay.reshape(
kernel_OHWoIi,
(out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems),
)
kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))

# update new attrs
n_elems = 4
new_attrs["channels"] = out_channel
new_attrs["data_layout"] = "NCHW%dc" % ic_bn
new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
new_attrs["out_layout"] = "NCHW%dc" % oc_bn

# Store altered operator's config.
Expand All @@ -208,7 +195,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)

return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
if data_layout == "NCHW" and kernel_layout == "OIHW":
Expand Down
104 changes: 83 additions & 21 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
ICHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
ICHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
Expand Down Expand Up @@ -182,15 +181,20 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor)
/*!
 * \brief Accumulated blocking factor of the subordinate form of \p axis.
 *
 * With multi-blocking layouts (e.g. OIHW16i64o2i) the same subordinate axis
 * may appear more than once; the extents of all occurrences are multiplied
 * together (16i and 2i give a factor of 32 for axis i).
 *
 * \param axis The axis to query; converted to its subordinate form first.
 * \return The product of the extents of every subordinate occurrence of the
 *         axis, or -1 if the layout is undefined or the axis never occurs
 *         in subordinate form.
 */
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
  if (!defined()) return -1;
  const LayoutAxis& sub = axis.ToSubordinate();

  int32_t factor = 1;
  bool has_sub = false;
  for (const IterVar& itvar : operator->()->axes) {
    if (sub == LayoutAxis::Get(itvar)) {
      has_sub = true;
      // Check the cast result BEFORE dereferencing: a symbolic (non-constant)
      // extent makes as<IntImmNode>() return nullptr. Checking the value
      // instead (as before) would both crash on nullptr and wrongly reject
      // a literal extent of 0.
      const IntImmNode* extent = itvar->dom->extent.as<IntImmNode>();
      ICHECK(extent) << "subordinate axis " << sub << " must have a constant extent";
      factor *= static_cast<int32_t>(extent->value);
    }
  }
  return has_sub ? factor : -1;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -199,16 +203,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Layout(" << l->name << ")";
});

inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
const Layout& dst_layout) {
inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule,
const Layout& src_layout, const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
PrimExpr store(0);
PrimExpr index_store(0);

for (size_t j = 0; j < src_layout.ndim(); ++j) {
const auto& orig_axis = src_layout[j];
Expand All @@ -220,28 +224,63 @@ inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
if (factor > 0) {
orig_var = orig_var * factor;
}
store = store + orig_var;
index_store = index_store + orig_var;
} else {
store = store + orig_axis_impl->var;
PrimExpr factor(1);
for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) {
factor = factor * src_layout->axes[k]->dom->extent;
}
}
index_store = index_store + orig_axis_impl->var * factor;
}
}
}
if (tir::is_zero(store)) {
if (tir::is_zero(index_store)) {
// Not convertible
return false;
}

PrimExpr shape_store = index_store;
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = indexdiv(store, PrimExpr(factor));
shape_store = shapediv(index_store, PrimExpr(factor));
index_store = indexdiv(index_store, PrimExpr(factor));
}
} else {
store = indexmod(store, store_axis_impl->dom->extent);
PrimExpr stride(1);
PrimExpr factor(1);
for (size_t j = i; j < dst_layout.ndim(); ++j) {
if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) {
stride = stride * dst_layout->axes[j]->dom->extent;
if (j > i) {
factor = factor * dst_layout->axes[j]->dom->extent;
}
}
}
shape_store = indexdiv(indexmod(index_store, stride), factor);
index_store = indexdiv(indexmod(index_store, stride), factor);
}

rule->push_back(store);
index_rule->push_back(index_store);
shape_rule->push_back(shape_store);
}

std::stringstream ss;
ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
for (const auto& r : *index_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;

ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
for (const auto& r : *shape_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;
VLOG(1) << std::endl << ss.str();

return true;
}

Expand All @@ -265,15 +304,15 @@ Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index)
const BijectiveLayoutNode* self = operator->();
ICHECK_EQ(src_index.size(), self->src_layout->axes.size())
<< "Input mismatch with layout " << self->src_layout;
return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule);
}

/*!
 * \brief Map per-axis indices in the destination layout back to the source layout.
 * \param dst_index One index expression per axis of dst_layout.
 * \return One index expression per axis of src_layout.
 * \note Applies index_backward_rule (the rendered diff also showed the stale
 *       pre-rename `backward_rule` return, which no longer exists).
 */
Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
  ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* self = operator->();
  ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
      << "Output mismatch with layout " << self->dst_layout;
  return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule);
}

inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
Expand Down Expand Up @@ -331,19 +370,41 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
}
}
}

std::stringstream ss;
ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name()
<< ": [ ";
for (const auto& r : transform_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;

ss << "shape transform: [ ";
for (const auto& s : src_shape) {
ss << s << ", ";
}
ss << "] --> [ ";
for (const auto& r : result) {
ss << r << ", ";
}
ss << "]" << std::endl;
VLOG(1) << std::endl << ss.str();

return result;
}

/*!
 * \brief Map a shape from the source layout to the destination layout.
 *
 * Uses shape_forward_rule (built with shapediv, i.e. ceiling division), so a
 * source extent that is not divisible by the blocking factor is rounded up —
 * the padded shape still has room for every element.
 */
Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
  ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* self = operator->();
  return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes,
                        self->shape_forward_rule);
}

/*!
 * \brief Map a shape from the destination layout back to the source layout.
 *
 * Uses shape_backward_rule, the shape-specific counterpart of
 * index_backward_rule (shapes round up on block splits; indices floor-divide).
 */
Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
  ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
  const BijectiveLayoutNode* self = operator->();
  return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes,
                        self->shape_backward_rule);
}

BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
Expand All @@ -354,8 +415,9 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {

// To be consistent with previous behavior, a nullptr layout is created
// when argument is invalid.
if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
ICHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) {
ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout,
n->src_layout));
data_ = std::move(n);
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); }
// TODO(tqchen): switch to floordiv
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, span); }

// shapediv == ceildiv: splitting a shape extent by a blocking factor must
// round UP so that non-divisible (padded) blocked layouts keep room for
// every element, unlike indexdiv which floors.
PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span); }

PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); }

PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
Expand All @@ -380,6 +382,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
return tir::FloorDiv(a, b, span);
}

/*!
 * \brief Build ceil(a / b) as floordiv(a + b - 1, b), with eager constant
 *        folding when both operands fold to integer immediates.
 *
 * The identity floordiv(a + b - 1, b) == ceil(a / b) holds for any integer a
 * when b > 0. NOTE(review): the intermediate a + b - 1 can overflow for `a`
 * near the dtype maximum — presumably shapes stay far below that; confirm.
 *
 * \param a left operand (integer dtype).
 * \param b right operand (integer dtype).
 * \param span The location of this operation in the source.
 * \return The (possibly const-folded) result expression.
 */
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
  ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  BinaryOpMatchTypes(a, b, span);
  // Build the numerator once instead of constructing a + b - 1 twice
  // (once for the fold attempt and once for the FloorDiv node).
  PrimExpr numerator = a + b - 1;
  PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(numerator, b);
  if (ret.defined()) return ret;
  return tir::FloorDiv(numerator, b, span);
}

PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,52 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_multi():
    """Test altering a conv2d to a multi-blocking kernel layout.

    The alter hook rewrites the kernel layout to OHWI16i64o2i, which blocks
    the input channel axis twice (16i and 2i). AlterOpLayout must insert
    layout_transform ops for both the data (NCHW -> NCHW16c) and the weight
    (OIHW -> OHWI16i64o2i).
    """

    def before():
        # Weight is left shapeless; InferType fills in (128, 64, 3, 3).
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=128, kernel_size=(3, 3), padding=(1, 1))
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        # Request a multi-blocking kernel layout: i split into 16i and 2i.
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        new_attrs["kernel_layout"] = "OHWI16i64o2i"
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(128, 64, 3, 3))

        # Both operands get explicit layout transforms around the conv.
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        w = relay.layout_transform(weight, "OIHW", "OHWI16i64o2i")
        y = relay.nn.conv2d(
            y,
            w,
            channels=128,
            kernel_size=(3, 3),
            padding=(1, 1),
            kernel_layout="OHWI16i64o2i",
            data_layout="NCHW16c",
        )
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_lrn():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
Expand Down

0 comments on commit 8d76075

Please sign in to comment.