Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] add support for multi-blocking layout and their transformation #9996

Merged
merged 8 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,13 @@ class BijectiveLayoutNode : public Object {
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
*/
Array<PrimExpr> forward_rule;
Array<PrimExpr> index_forward_rule;
/*! \brief Describes how destination axes can be mapped to the source axes */
Array<PrimExpr> backward_rule;
Array<PrimExpr> index_backward_rule;
/*! \brief Describes how source shapes can be mapped to the destination shapes */
Array<PrimExpr> shape_forward_rule;
/*! \brief Describes how destination shapes can be mapped to the source shapes */
Array<PrimExpr> shape_backward_rule;

/*! \brief The source layout */
Layout src_layout;
Expand All @@ -307,8 +311,10 @@ class BijectiveLayoutNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
v->Visit("backward_rule", &backward_rule);
v->Visit("index_forward_rule", &index_forward_rule);
v->Visit("index_backward_rule", &index_backward_rule);
v->Visit("shape_forward_rule", &shape_forward_rule);
v->Visit("shape_backward_rule", &shape_backward_rule);
}

static constexpr const char* _type_key = "tir.BijectiveLayout";
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,22 @@ TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute ceil(a / b) where a and b are non-negative.
*
* Use this function for shape split calculation.
*
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \param span The location of this operation in the source.
* \return The result expression.
* \note this function does eager constant folding for
* shape types(int32, int64) when possible.
*/
TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name shapediv could be confusing given indexdiv used the other case, how about nonneg_ceildiv?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a kind of symmetry, as indexdiv is an alias of floordiv to prevent access out-of-boundary, and shapediv is an alias of ceildiv to prevent the shrink of a Tensor.
If this is confusing, I prefer to remove indexdiv and shapediv since they are just aliases of floordiv and ceildiv now, or we can keep them and add check codes for non-negative then change their names to nonneg_floordiv/ceildiv.

/*!
* \brief compute the remainder of floordiv(a, b), i.e. a - floordiv(a, b) * b, where a and b are non-negative.
*
Expand All @@ -521,6 +537,17 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute ceil(a / b)
*
* \param a left operand
* \param b right operand
* \param span The location of this operation in the source.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute the remainder of floordiv
*
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,11 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
[&](const Array<Var>& dst_indices) {
Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true
for (size_t i = 0; i < src.ndim(); ++i) {
in_range = in_range && (src_indices[i] < src->shape[i]);
}
return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
masahi marked this conversation as resolved.
Show resolved Hide resolved
},
name, tag);
}
Expand Down
19 changes: 3 additions & 16 deletions python/tvm/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
n_elems = 4

# convert kernel data layout from 4D to 7D
data_expr, kernel_expr = inputs
kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn))
kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
kernel_OHWoIi = relay.reshape(
kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn)
)
kernel_OHWoIie = relay.reshape(
kernel_OHWoIi,
(out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems),
)
kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))

# update new attrs
n_elems = 4
new_attrs["channels"] = out_channel
new_attrs["data_layout"] = "NCHW%dc" % ic_bn
new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
masahi marked this conversation as resolved.
Show resolved Hide resolved
new_attrs["out_layout"] = "NCHW%dc" % oc_bn

# Store altered operator's config.
Expand All @@ -208,7 +195,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)

return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
if data_layout == "NCHW" and kernel_layout == "OIHW":
Expand Down
104 changes: 83 additions & 21 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
ICHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
ICHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
Expand Down Expand Up @@ -182,15 +181,20 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor)
// Returns the accumulated blocking factor of the subordinate form of `axis`
// in this layout, or -1 when the layout is undefined or contains no such
// subordinate axis. With multi-blocking layouts (e.g. OIHW16i64o2i) the same
// subordinate axis may appear more than once; the factors are multiplied.
// NOTE(review): this span renders both the pre-change and post-change lines
// of the diff; read the added lines as the current implementation.
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
if (!defined()) return -1;
const LayoutAxis& sub = axis.ToSubordinate();
// NOTE(review): duplicate of the `defined()` guard two lines above — one of
// the two checks is redundant.
if (!this->defined()) return -1;

int32_t factor = 1;
bool has_sub = false;
for (const IterVar& itvar : operator->()->axes) {
if (sub == LayoutAxis::Get(itvar)) {
const auto* factor = itvar->dom->extent.as<IntImmNode>();
ICHECK(factor);
return factor->value;
has_sub = true;
// NOTE(review): `as<IntImmNode>()` returns nullptr for a non-constant
// extent and is dereferenced here before any check; `ICHECK(val)` then
// validates the *value* (rejecting 0), not the pointer. Consider
// checking the pointer before dereferencing, as the removed lines did.
int32_t val = itvar->dom->extent.as<IntImmNode>()->value;
ICHECK(val);
factor *= val;
}
}
return -1;
// Preserve the old contract: -1 signals "no subordinate axis found".
factor = has_sub ? factor : -1;

return factor;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -199,16 +203,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Layout(" << l->name << ")";
});

inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
const Layout& dst_layout) {
inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule,
const Layout& src_layout, const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
PrimExpr store(0);
PrimExpr index_store(0);

for (size_t j = 0; j < src_layout.ndim(); ++j) {
const auto& orig_axis = src_layout[j];
Expand All @@ -220,28 +224,63 @@ inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
if (factor > 0) {
orig_var = orig_var * factor;
}
store = store + orig_var;
index_store = index_store + orig_var;
} else {
store = store + orig_axis_impl->var;
PrimExpr factor(1);
for (size_t k = j + 1; k < src_layout.ndim(); ++k) {
if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) {
factor = factor * src_layout->axes[k]->dom->extent;
}
}
index_store = index_store + orig_axis_impl->var * factor;
}
}
}
if (tir::is_zero(store)) {
if (tir::is_zero(index_store)) {
// Not convertible
return false;
}

PrimExpr shape_store = index_store;
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = indexdiv(store, PrimExpr(factor));
shape_store = shapediv(index_store, PrimExpr(factor));
index_store = indexdiv(index_store, PrimExpr(factor));
}
} else {
store = indexmod(store, store_axis_impl->dom->extent);
PrimExpr stride(1);
PrimExpr factor(1);
for (size_t j = i; j < dst_layout.ndim(); ++j) {
if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) {
stride = stride * dst_layout->axes[j]->dom->extent;
if (j > i) {
factor = factor * dst_layout->axes[j]->dom->extent;
}
}
}
shape_store = indexdiv(indexmod(index_store, stride), factor);
index_store = indexdiv(indexmod(index_store, stride), factor);
}

rule->push_back(store);
index_rule->push_back(index_store);
shape_rule->push_back(shape_store);
}

std::stringstream ss;
ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
for (const auto& r : *index_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;

ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ ";
for (const auto& r : *shape_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;
VLOG(1) << std::endl << ss.str();

return true;
}

Expand All @@ -265,15 +304,15 @@ Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index)
const BijectiveLayoutNode* self = operator->();
ICHECK_EQ(src_index.size(), self->src_layout->axes.size())
<< "Input mismatch with layout " << self->src_layout;
return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule);
}

Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
<< "Output mismatch with layout " << self->dst_layout;
return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule);
}

inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
Expand Down Expand Up @@ -331,19 +370,41 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
}
}
}

std::stringstream ss;
ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name()
<< ": [ ";
for (const auto& r : transform_rule) {
ss << r << ", ";
}
ss << "]" << std::endl;

ss << "shape transform: [ ";
for (const auto& s : src_shape) {
ss << s << ", ";
}
ss << "] --> [ ";
for (const auto& r : result) {
ss << r << ", ";
}
ss << "]" << std::endl;
VLOG(1) << std::endl << ss.str();

return result;
}

Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule);
return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes,
self->shape_forward_rule);
}

Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
ICHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule);
return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes,
self->shape_backward_rule);
}

BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
Expand All @@ -354,8 +415,9 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {

// To be consistent with previous behavior, a nullptr layout is created
// when argument is invalid.
if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
ICHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) {
ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout,
n->src_layout));
data_ = std::move(n);
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); }
// TODO(tqchen): switch to floordiv
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, span); }

PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span); }

PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); }

PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
Expand All @@ -380,6 +382,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
return tir::FloorDiv(a, b, span);
}

/*!
 * \brief Compute ceil(a / b) for integer PrimExprs.
 *
 * Implemented as floor((a + b - 1) / b), which equals ceil(a / b) for
 * integral a and positive b. Eagerly constant-folds when both operands are
 * constant (via TryConstFold on the FloorDiv form).
 */
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
  ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  BinaryOpMatchTypes(a, b, span);
  // Build the numerator expression once instead of materializing the
  // identical (a + b - 1) tree twice.
  PrimExpr numerator = a + b - 1;
  PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(numerator, b);
  if (ret.defined()) return ret;
  return tir::FloorDiv(numerator, b, span);
}

PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,52 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_multi():
    """Test altering a conv2d to a multi-blocking kernel layout.

    The alter callback requests kernel_layout "OHWI16i64o2i", in which the
    subordinate axis `i` appears twice (16i ... 2i). AlterOpLayout should
    insert the matching layout_transform for the weight and the NCHW16c
    transforms for the data/output.
    """

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=128, kernel_size=(3, 3), padding=(1, 1))
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        # Request a blocked data layout and a multi-blocked kernel layout.
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        new_attrs["kernel_layout"] = "OHWI16i64o2i"
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(128, 64, 3, 3))

        # Data and weight are transformed into the layouts requested above;
        # the conv2d output is transformed back to NCHW.
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        w = relay.layout_transform(weight, "OIHW", "OHWI16i64o2i")
        y = relay.nn.conv2d(
            y,
            w,
            channels=128,
            kernel_size=(3, 3),
            padding=(1, 1),
            kernel_layout="OHWI16i64o2i",
            data_layout="NCHW16c",
        )
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_lrn():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
Expand Down