From ebbda97ce0d5e89e24ff11244787296605acac5c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 1 Dec 2020 19:54:40 -0800 Subject: [PATCH] fix attrs --- include/tvm/relay/attrs/nn.h | 3 --- python/tvm/relay/op/strategy/generic.py | 8 ++++++-- src/relay/transforms/auto_scheduler_layout_rewrite.cc | 10 ++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 67b719ef1dcdb..f8aa1fc508b63 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -171,9 +171,6 @@ struct Conv2DAttrs : public tvm::AttrsNode { "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); - TVM_ATTR_FIELD(auto_scheduler_rewritten_layout) - .set_default("") - .describe("New kernel layout after auto-scheduler's layout rewrite."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f746926880cfe..e2ab3cd4f5d46 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -19,7 +19,7 @@ import logging import re -from tvm import topi +from tvm import topi, _ffi from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple from tvm.target import generic_func, override_native_generic_func from .. import op as _op @@ -166,6 +166,10 @@ def schedule_bitpack(attrs, outs, target): return topi.generic.schedule_bitpack(outs) +get_auto_scheduler_rewritten_layout = _ffi.get_global_func( + "relay.attrs.get_auto_scheduler_rewritten_layout" +) + # conv2d def wrap_compute_conv2d( topi_compute, @@ -183,7 +187,7 @@ def _compute_conv2d(attrs, inputs, out_type): data_layout = attrs.get_str("data_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype - auto_scheduler_rewritten_layout = attrs.get_str("auto_scheduler_rewritten_layout") + auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs) out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype args = [inputs[0], inputs[1], strides, padding, dilation] if has_groups: diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 259ae8fc2a8e5..19b63151918ec 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -146,6 +146,16 @@ Pass AutoSchedulerLayoutRewrite() { TVM_REGISTER_GLOBAL("relay._transform.AutoSchedulerLayoutRewrite") .set_body_typed(AutoSchedulerLayoutRewrite); +TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout") + .set_body_typed([](const Attrs& attrs) { + if (attrs->IsInstance()) { + return attrs.as()->auto_scheduler_rewritten_layout; + } else { + LOG(FATAL) << "Cannot get auto_scheduler_rewritten_layout from " << attrs; + } + return std::string(); + }); + } // namespace transform } // namespace relay