diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 94ba853c493a..5cd860b8e929 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -521,6 +521,13 @@ class AllocateNode : public StmtNode { PrimExpr condition; /*! \brief The body to be executed. */ Stmt body; + /*! + * \brief Additional annotations about the allocation. + * + * These annotations can be used as auxiliary hint + * to future transformations. + */ + Map annotations; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); @@ -528,13 +535,14 @@ class AllocateNode : public StmtNode { v->Visit("extents", &extents); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("annotations", &annotations); v->Visit("span", &span); } bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && equal(extents, other->extents) && equal(condition, other->condition) && - equal(body, other->body); + equal(body, other->body) && equal(annotations, other->annotations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -543,6 +551,7 @@ class AllocateNode : public StmtNode { hash_reduce(extents); hash_reduce(condition); hash_reduce(body); + hash_reduce(annotations); } /*! @@ -570,7 +579,8 @@ class AllocateNode : public StmtNode { class Allocate : public Stmt { public: TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Span span = Span()); + Stmt body, Map annotations = Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); }; diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index d57077f08b52..de200d5eabdd 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -318,13 +318,25 @@ class Allocate(Stmt): body : Stmt The body statement. + annotations: Optional[Mapping[str, Object]] + Additional annotation hints + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, buffer_var, dtype, extents, condition, body, span=None): + def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None): + if annotations is None: + annotations = dict() self.__init_handle_by_constructor__( - _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore + _ffi_api.Allocate, # type: ignore + buffer_var, + dtype, + extents, + condition, + body, + annotations, + span, ) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d59c94dc5753..0d42c20c2822 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Span span) { + Stmt body, Map annotations, Span span) { CHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" @@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); + node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } @@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array& extents) { TVM_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body, Span span) { - return Allocate(buffer_var, type, extents, condition, body, span); + Stmt body, Map annotations, Span span) { + return Allocate(buffer_var, type, extents, condition, body, annotations, span); }); TVM_REGISTER_NODE_TYPE(AllocateNode); diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index de94464187b0..848b3eed07ea 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -473,5 +473,28 @@ def test_block_blockrealize(): assert output.find("with init()") != -1 +def test_tir_allocate(): + dtype = "int8" + storage_scope = "global" + ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + a = te.var("buffer", ptype) + allocate = tvm.tir.Allocate( + buffer_var=a, + dtype=dtype, + extents=[2, 2], + condition=tvm.get_global_func("tir.const_true")(dtype, None), + body=tvm.tir.Evaluate(2 + 1), + annotations={ + "attr1": "foo", + "attr2": "bar", + }, + ) + assert allocate.buffer_var == a + assert allocate.dtype == "int8" + assert list(allocate.extents) == [2, 2] + assert allocate.annotations["attr1"] == "foo" + assert allocate.annotations["attr2"] == "bar" + + if __name__ == "__main__": pytest.main([__file__])