diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 2507262c087f..28d202cb50a9 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -67,8 +67,6 @@ class BufferNode : public Object { // Meta data /*! \brief optional name of the buffer */ String name; - /*! \brief storage scope of the buffer, if other than global */ - String scope; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -93,7 +91,6 @@ class BufferNode : public Object { v->Visit("strides", &strides); v->Visit("elem_offset", &elem_offset); v->Visit("name", &name); - v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); v->Visit("buffer_type", &buffer_type); @@ -105,7 +102,7 @@ class BufferNode : public Object { // in its semantics, skip name as name is not important. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } @@ -115,7 +112,6 @@ class BufferNode : public Object { hash_reduce.DefHash(shape); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); - hash_reduce(scope); hash_reduce(data_alignment); hash_reduce(buffer_type); } @@ -141,8 +137,8 @@ class Buffer : public ObjectRef { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span = Span()); + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -182,6 +178,11 @@ class Buffer : public ObjectRef { */ TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + /*! + * \brief Return the storage scope associated with this buffer. + */ + TVM_DLL String scope() const; + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); }; diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index caca1e85e520..2561f8d1ca27 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -48,7 +48,7 @@ using namespace tvm::te; inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, -1, 0, kDefault); } /*! diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index e40bc2fda6eb..24b2f3af9ab0 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -463,7 +463,7 @@ def match_buffer_region( data=None, strides=strides, elem_offset=elem_offset, - scope=buffer_region.buffer.scope, + scope=buffer_region.buffer.scope(), data_alignment=align, offset_factor=offset_factor, span=span, diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 086d93f49a2b..b445bcb25005 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -134,6 +134,15 @@ def vstore(self, begin, value): begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin return _ffi_api.BufferVStore(self, begin, value) # type: ignore + def scope(self): + """Return the storage scope associated with this buffer. + Returns + ------- + scope : str + The storage scope associated with this buffer. + """ + return _ffi_api.BufferStorageScope(self) # type: ignore + def decl_buffer( shape, @@ -260,7 +269,6 @@ def decl_buffer( strides, elem_offset, name, - scope, data_alignment, offset_factor, buffer_type, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 50f00140df9b..b0434049f60f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -88,7 +88,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, data_alignment, offset_factor, buffer_type); } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0fefb0515e49..0f3b89932b68 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -204,8 +204,8 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (GetRef(buf).scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } if (buf->data_alignment != 128) { doc << ", align=" << buf->data_alignment; diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e855712617ca..01f79bd0c750 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -302,8 +302,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { } else { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (buf.scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(buf.scope()); } if (buf->data_alignment != -1) { doc << ", align=" << buf->data_alignment; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 435870a5e5cc..335ff19dd775 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -51,7 +51,7 @@ Buffer decl_buffer(Array shape, DataType dtype, String name, String st Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, - Array(), PrimExpr(), name, "", 0, 0, kDefault, span); + Array(), PrimExpr(), name, 0, 0, kDefault, span); } // Split the given expression w.r.t the add operator @@ -319,6 +319,15 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { } } +String Buffer::scope() const { + const auto* ptr_type = (*this)->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (ptr_type->storage_scope.empty()) { + return "global"; + } + return ptr_type->storage_scope; +} + Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; @@ -358,7 +367,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->data_alignment, 0, n->buffer_type); } @@ -391,8 +400,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span) { + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -409,10 +418,6 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); - if (scope.length() == 0) { - scope = "global"; - } - n->scope = std::move(scope); if (!elem_offset.defined()) { elem_offset = make_const(n->DefaultIndexType(), 0); } @@ -444,11 +449,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 11); - auto buffer_type = args[9].operator String(); + ICHECK_EQ(args.size(), 10); + auto buffer_type = args[8].operator String(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], - type, args[10]); + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); @@ -457,5 +462,7 @@ TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ca61dfea2768..906f5aaabe08 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -43,7 +43,7 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, AsIntSet(LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)))); + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())))); } /*! @@ -67,7 +67,7 @@ Array AnalyzeRegionLowerBound(const BlockRealize& realize, LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)), + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())), /*predicate=*/realize->predicate, /*analyzer=*/analyzer)) { return result.value(); } diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 9cd29357f8c7..293c990d2745 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -88,7 +88,7 @@ void ArgBinder::BindArray(const Array& arg, const Array& val void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - ICHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; + ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch"; ICHECK_EQ(arg->dtype, value->dtype) << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 7a8789457923..76845cbebd2a 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -323,8 +323,8 @@ class BF16LowerRewriter : public StmtExprMutator { DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, - oldbuf->name, oldbuf->scope, oldbuf->data_alignment, - oldbuf->offset_factor, oldbuf->buffer_type); + oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor, + oldbuf->buffer_type); buffer_remap_[oldbuf] = newbuf; var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index edbafe27cf13..f69a9e54afa4 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -203,7 +203,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { std::unordered_map dom_map; for (const ForNode* loop : ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); - if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer->scope))) { + if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } } diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 07f7b42fe2eb..88c254a8cb5e 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -127,10 +127,7 @@ class BufferFlattener : public StmtExprMutator { } static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { - String storage_scope = buffer->scope; - if (storage_scope.empty()) { - storage_scope = "global"; - } + String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index f7443c74c0f7..40f0e368d93d 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -148,11 +148,9 @@ class CopyIntrinInjector : public StmtMutator { dst_strides.push_back(make_const(DataType::Int(32), 1)); } Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault); Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, - load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, - kDefault); + load->buffer_var->name_hint, 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); ICHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 0db86130a8da..5de22fe8665d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -198,7 +198,7 @@ class StorageFlattener : public StmtExprMutator { auto new_var = Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - skey.to_string(), align, 0, kDefault); + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -224,7 +224,7 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); + ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 7c7d02b40fbb..383841f19e34 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): # FIXME: pad_value is ignored... env = get_env() _ = pad_value - if dst.scope == "global": + if dst.scope() == "global": # Store if pad_before or pad_after: raise RuntimeError("Do not support copy into DRAM with pad") - if src.scope == env.acc_scope: + if src.scope() == env.acc_scope: elem_width = env.OUT_WIDTH elem_bytes = env.OUT_ELEM_BYTES mem_type = env.dev.MEM_ID_OUT data_type = "int%d" % env.OUT_WIDTH task_qid = env.dev.QID_STORE_OUT else: - raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + raise RuntimeError("Do not support copy %s->dram" % (src.scope())) _check_compact(src) x_size, y_size, x_stride, offset = _get_2d_pattern( - dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True + dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True ) irb = tvm.tir.ir_builder.create() irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid)) @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): ) ) return irb.get() - elif src.scope == "global": - if dst.scope == env.acc_scope: + elif src.scope() == "global": + if dst.scope() == env.acc_scope: elem_width = env.ACC_WIDTH elem_bytes = env.ACC_ELEM_BYTES mem_type = env.dev.MEM_ID_ACC data_type = "int%d" % env.ACC_WIDTH task_qid = env.dev.QID_LOAD_OUT - elif dst.scope == env.inp_scope: + elif dst.scope() == env.inp_scope: elem_width = env.INP_WIDTH elem_bytes = env.INP_ELEM_BYTES mem_type = env.dev.MEM_ID_INP data_type = "int%d" % env.INP_WIDTH task_qid = env.dev.QID_LOAD_INP - elif dst.scope == env.wgt_scope: + elif dst.scope() == env.wgt_scope: elem_width = env.WGT_WIDTH elem_bytes = env.WGT_ELEM_BYTES mem_type = env.dev.MEM_ID_WGT data_type = "int%d" % env.WGT_WIDTH task_qid = env.dev.QID_LOAD_WGT else: - raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + raise RuntimeError("Do not support copy dram->%s" % (dst.scope())) # collect pad statistics if pad_before: assert pad_after @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): _check_compact(dst) x_size, y_size, x_stride, offset = _get_2d_pattern( - src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold + src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold ) if data_type != src.dtype: @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): return irb.get() else: - raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) + raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope())) return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)