Skip to content

Commit

Permalink
[Relay][Op] MetaSchedule layout in TypeRel (#11819)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
jinhongyii and junrushao authored Jun 22, 2022
1 parent 5056eb7 commit 98fb955
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 46 deletions.
18 changes: 12 additions & 6 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
Expand Down Expand Up @@ -217,7 +218,8 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
Expand Down Expand Up @@ -308,7 +310,8 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
Expand Down Expand Up @@ -1049,7 +1052,8 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights

TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
Expand All @@ -1072,7 +1076,8 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights
DataType out_dtype;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
Expand Down Expand Up @@ -1109,7 +1114,8 @@ struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
Array<PrimExpr> meta_schedule_original_shape; // The original shape of the weights

TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
// use 0 bits to indicate none.
Expand Down
60 changes: 38 additions & 22 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Layout kOIHW("OIHW");

const auto* param = attrs.as<Conv2DAttrs>();
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
if (out_dtype.bits() == 0 && weight != nullptr) {
out_dtype = weight->dtype;
}
}
TensorType meta_schedule_weight{nullptr};
if (param->meta_schedule_original_shape.size() != 0) {
meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
weight = meta_schedule_weight.get();
}
ICHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
Expand Down Expand Up @@ -273,27 +285,27 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
weight_dtype = weight->dtype;
}

if (param->auto_scheduler_rewritten_layout.size() == 0) {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
if (param->auto_scheduler_rewritten_layout.size() != 0) {
// If the layout is rewritten by auto-scheduler,
// we just forcly apply the layout provided by auto-scheduler and
// skip the normal inference logic.
{} // do nothing
} else {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
}
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;

Array<PrimExpr> wshape;
if (param->auto_scheduler_rewritten_layout.size() == 0) {
wshape = weight->shape;
} else {
if (param->auto_scheduler_rewritten_layout.size() != 0) {
// works for the default kernel layout "HWIO"
ICHECK_EQ(param->kernel_layout, "HWIO");
wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
{"ry", "rx", "rc", "ff"});
} else {
wshape = weight->shape;
}

wshape = trans_kernel_layout.ForwardShape(wshape);
Expand Down Expand Up @@ -357,10 +369,6 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
oshape.Set(3, dshape_nchw[3]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
Expand Down Expand Up @@ -412,6 +420,18 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

const auto* param = attrs.as<Conv3DAttrs>();
ICHECK(param != nullptr);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
if (out_dtype.bits() == 0 && weight != nullptr) {
out_dtype = weight->dtype;
}
}
TensorType meta_schedule_weight{nullptr};
if (param->meta_schedule_original_shape.size() != 0) {
meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
weight = meta_schedule_weight.get();
}
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

Expand Down Expand Up @@ -450,28 +470,28 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
weight_dtype = weight->dtype;
}

if (param->auto_scheduler_rewritten_layout.size() == 0) {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
if (param->auto_scheduler_rewritten_layout.size() != 0) {
// If the layout is rewritten by auto-scheduler,
// we just forcly apply the layout provided by auto-scheduler and
// skip the normal inference logic.
{} // do nothing
} else {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
}

} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;

Array<PrimExpr> wshape;
if (param->auto_scheduler_rewritten_layout.size() == 0) {
wshape = weight->shape;
} else {
if (param->auto_scheduler_rewritten_layout.size() != 0) {
// works for the default kernel layout "DHWIO"
ICHECK_EQ(param->kernel_layout, "DHWIO");
wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
{"rd", "rh", "rw", "rc", "cc"});
} else {
wshape = weight->shape;
}

wshape = trans_kernel_layout.ForwardShape(wshape);
Expand Down Expand Up @@ -521,10 +541,6 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
oshape.Set(4, dshape_ncdhw[4]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
Expand Down
54 changes: 36 additions & 18 deletions src/relay/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

const AttrType* param = attrs.as<AttrType>();
ICHECK(param != nullptr);
TensorType meta_schedule_tensor_b{nullptr};
if (param->meta_schedule_original_shape.size() > 0) {
meta_schedule_tensor_b = TensorType(param->meta_schedule_original_shape,
tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
tensor_b = meta_schedule_tensor_b.get();
}
// Default set to dense layout
bool transpose_a = false;
bool transpose_b = true;
Expand All @@ -73,14 +79,14 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly
// present we will use that.
auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
if (param->auto_scheduler_rewritten_layout.size() == 0) {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
} else {
// If the layout is rewritten by auto-scheduler,
// we just forcly apply the layout provided by auto-scheduler and
if (param->auto_scheduler_rewritten_layout.size() != 0) {
// If the layout is rewritten by auto-scheduler or meta-schedule,
// we just forcefully apply the layout provided by auto-scheduler and
// skip the normal inference logic.
{} // do nothing
} else {
// Normal case: assign result to reporter
reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
}
oshape.Set((oshape.size() - 1), param->units);
} else {
Expand All @@ -103,7 +109,7 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "MatmulRel: input dimension doesn't match,"
<< " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape;
}
oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]);
oshape.Set(oshape.size() - 1, transpose_b ? wshape[0] : wshape[1]);
}
}

Expand All @@ -125,16 +131,32 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (x == nullptr || y == nullptr) return false;

const AttrType* param = attrs.as<AttrType>();
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = x->dtype;
if (x->dtype.bits() == 0) {
out_dtype = y->dtype;
}
}
TensorType meta_schedule_y{nullptr};
if (param->meta_schedule_original_shape.size() != 0) {
meta_schedule_y = TensorType(param->meta_schedule_original_shape, out_dtype);
y = meta_schedule_y.get();
}
ICHECK(param != nullptr);
bool transpose_a = param->transpose_a;
bool transpose_b = param->transpose_b;
const Array<PrimExpr>& y_shape =
param->auto_scheduler_rewritten_layout.size() == 0
? y->shape
: auto_scheduler::GetShapeFromRewrittenLayout(
param->auto_scheduler_rewritten_layout,
transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
: tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
Array<PrimExpr> y_shape{nullptr};
if (param->auto_scheduler_rewritten_layout.size() != 0) {
y_shape = auto_scheduler::GetShapeFromRewrittenLayout(
param->auto_scheduler_rewritten_layout,
transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
: tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
} else if (param->meta_schedule_original_shape.size() != 0) {
y_shape = param->meta_schedule_original_shape;
} else {
y_shape = y->shape;
}
ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
const PrimExpr& xb = x->shape[0];
const PrimExpr& xi = x->shape[transpose_a ? 2 : 1];
Expand All @@ -158,10 +180,6 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
<< " x shape=" << x->shape << ", y shape=" << y_shape;
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = x->dtype;
}
// assign output type
const auto& out_b =
xb->IsInstance<tir::AnyNode>() || yb->IsInstance<tir::AnyNode>() ? tir::Any() : max(xb, yb);
Expand Down
1 change: 1 addition & 0 deletions src/relay/transforms/fold_explicit_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class SimplifyExplicitPad {

T* new_attrs = const_cast<T*>(attrs.template as<T>());
new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
new_attrs->meta_schedule_original_shape = old_attrs->meta_schedule_original_shape;
return attrs;
}

Expand Down

0 comments on commit 98fb955

Please sign in to comment.