Skip to content

Commit

Permalink
fix attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Dec 2, 2020
1 parent 3d492b3 commit ebbda97
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 0 additions & 3 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
"Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi
from tvm import topi, _ffi
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
Expand Down Expand Up @@ -166,6 +166,10 @@ def schedule_bitpack(attrs, outs, target):
return topi.generic.schedule_bitpack(outs)


get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
"relay.attrs.get_auto_scheduler_rewritten_layout"
)

# conv2d
def wrap_compute_conv2d(
topi_compute,
Expand All @@ -183,7 +187,7 @@ def _compute_conv2d(attrs, inputs, out_type):
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout")
auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
Expand Down
10 changes: 10 additions & 0 deletions src/relay/transforms/auto_scheduler_layout_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ Pass AutoSchedulerLayoutRewrite() {
TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite")
.set_body_typed(AutoSchedulerLayoutRewrite);

TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
.set_body_typed([](const Attrs& attrs) {
if (attrs->IsInstance<Conv2DAttrs>()) {
return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
} else {
LOG(FATAL) << "Cannot get auto_scheduler_rewritten_layout from " << attrs;
}
return std::string();
});

} // namespace transform

} // namespace relay
Expand Down

0 comments on commit ebbda97

Please sign in to comment.