Skip to content

Commit

Permalink
Adding annotation for tir.allocate
Browse files Browse the repository at this point in the history
This commit is adding annotations for tir.allocate
node to be used as hints for future transformations.

Change-Id: I02a3a875c38c3edd449385da5b741ef4958bb47f
  • Loading branch information
manupak committed Oct 5, 2021
1 parent d9a5ff5 commit 4971d09
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
14 changes: 12 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,20 +521,28 @@ class AllocateNode : public StmtNode {
PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
/*!
* \brief Additional annotations about the allocation.
*
* These annotations can be used as auxiliary hint
* to future transformations.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(condition, other->condition) &&
equal(body, other->body);
equal(body, other->body) && equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -543,6 +551,7 @@ class AllocateNode : public StmtNode {
hash_reduce(extents);
hash_reduce(condition);
hash_reduce(body);
hash_reduce(annotations);
}

/*!
Expand Down Expand Up @@ -570,7 +579,8 @@ class AllocateNode : public StmtNode {
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span = Span());
Stmt body, Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,25 @@ class Allocate(Stmt):
body : Stmt
The body statement.
annotations: Optional[Mapping[str, Object]]
Additional annotation hints
span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
if annotations is None:
annotations = dict()
self.__init_handle_by_constructor__(
_ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore
_ffi_api.Allocate, # type: ignore
buffer_var,
dtype,
extents,
condition,
body,
annotations,
span,
)


Expand Down
7 changes: 4 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span) {
Stmt body, Map<String, ObjectRef> annotations, Span span) {
CHECK(IsPointerType(buffer_var->type_annotation, dtype))
<< "The allocated data type (" << dtype
<< ") does not match the type annotation of the buffer " << buffer_var << " ("
Expand All @@ -354,6 +354,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}
Expand All @@ -375,8 +376,8 @@ int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {

TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span) {
return Allocate(buffer_var, type, extents, condition, body, span);
Stmt body, Map<String, ObjectRef> annotations, Span span) {
return Allocate(buffer_var, type, extents, condition, body, annotations, span);
});

TVM_REGISTER_NODE_TYPE(AllocateNode);
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,28 @@ def test_block_blockrealize():
assert output.find("with init()") != -1


def test_tir_allocate():
dtype = "int8"
storage_scope = "global"
ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
a = te.var("buffer", ptype)
allocate = tvm.tir.Allocate(
buffer_var=a,
dtype=dtype,
extents=[2, 2],
condition=tvm.get_global_func("tir.const_true")(dtype, None),
body=tvm.tir.Evaluate(2 + 1),
annotations={
"attr1": "foo",
"attr2": "bar",
},
)
assert allocate.buffer_var == a
assert allocate.dtype == "int8"
assert list(allocate.extents) == [2, 2]
assert allocate.annotations["attr1"] == "foo"
assert allocate.annotations["attr2"] == "bar"


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 4971d09

Please sign in to comment.