Skip to content

Commit

Permalink
[TIR] Get read/write access precisely for opaque access. (apache#11110)
Browse files Browse the repository at this point in the history
* [TIR] Get read/write access precisely for opaque access.

When an opaque access is wrapped in tvm_access_ptr, BlockReadWriteDetector can read
the access_mask argument of the tvm_access_ptr call and record the access in
read_regions or write_regions according to that mask.

* [TIR] Add parameter extent for access_ptr.

Co-authored-by: sqing <qing.siqi@intellif.com>
  • Loading branch information
2 people authored and Sergey Shtin committed May 17, 2022
1 parent fc35569 commit 96a274a
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 41 deletions.
5 changes: 3 additions & 2 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ class Buffer : public ObjectRef {
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
* \param input_extent The extent of ptr. When defined, it is used as the extent of the access pointer instead of the default extent.
*/
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
Optional<PrimExpr> input_extent = NullOpt) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Buffer(Object):
READ = 1
WRITE = 2

def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
Expand All @@ -66,6 +66,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
extent: Expr, optional
The extent of pointer. If given, it is used as the extent of the returned access pointer.
Examples
--------
.. code-block:: python
Expand All @@ -78,6 +81,8 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
Expand All @@ -90,8 +95,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
extent = convert(extent)
return _ffi_api.BufferAccessPtr(
self, access_mask, ptr_type, content_lanes, offset # type: ignore
self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
)

def vload(self, begin, dtype=None):
Expand Down
28 changes: 28 additions & 0 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,34 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
}

void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode* buffer_var = op->args[1].as<VarNode>();
const IntImmNode* access_mask = op->args[4].as<IntImmNode>();
if (buffer_var && access_mask) {
auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
if (it != buffer_var_map_.end()) {
const Buffer& buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
const Region& region = buffer_region->region;
std::vector<arith::IntSet> int_set;
int_set.reserve(region.size());
for (const Range& range : region) {
int_set.push_back(arith::EvalSet(range, dom_map_));
}
// read access, write access or opaque access
if ((access_mask->value & 1) && (access_mask->value & 2)) {
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
} else if (access_mask->value & 1) {
Update(&read_buffers_, &read_regions_, buffer, int_set);
} else if (access_mask->value & 2) {
Update(&writes_buffers_, &write_regions_, buffer, int_set);
}
}
} else {
StmtExprVisitor::VisitExpr_(op);
}
return;
}
if (op->op.same_as(builtin::if_then_else())) {
VisitExpr(op->args[0]);
{
Expand Down
8 changes: 6 additions & 2 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return slice;
}

PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
PrimExpr offset) const {
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
Optional<PrimExpr> input_extent) const {
const BufferNode* self = operator->();
ICHECK(self != nullptr);
PrimExpr e_dtype;
Expand All @@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
} else {
e_dtype = tir::TypeAnnotation(self->dtype);
}

if (input_extent.defined()) {
extent = input_extent.value();
}
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
make_const(DataType::Int(32), access_mask)};
return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,19 @@ def opaque_access_func() -> None:
)


@T.prim_func
def opaque_access_with_tvm_access_ptr_func() -> None:
    # Fixture: one block whose only buffer accesses are `access_ptr` calls,
    # one call per access mask string ("r", "w" and "rw").
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    C = T.alloc_buffer([1024])
    with T.block("opaque"):
        # Annotated regions: A (accessed with "r") is listed under reads only,
        # B (accessed with "w") under writes only, and C (accessed with "rw")
        # under both reads and writes.
        T.reads(A[0:1024], C[0:1024])
        T.writes(B[0:1024], C[0:1024])
        T.evaluate(A.access_ptr("r"))
        T.evaluate(B.access_ptr("w"))
        T.evaluate(C.access_ptr("rw"))


@T.prim_func
def access_in_if_then_else_func() -> None:
A = T.alloc_buffer([8])
Expand Down Expand Up @@ -235,6 +248,21 @@ def test_opaque_access():
tvm.ir.assert_structural_equal(ret0[1], ret1[1])


def test_opaque_access_with_tvm_access_ptr():
    """Check region detection on a block that only uses ``access_ptr`` calls.

    ``get_block_read_write_region`` must reproduce the block's annotated reads
    and writes. Comparing that result with the one from
    ``get_block_access_region`` is expected to raise ``ValueError`` for both
    the first (reads) and the second (writes) entries.
    """
    root_block = opaque_access_with_tvm_access_ptr_func.body.block
    block = root_block.body.block

    # Map every allocated buffer's data var to the buffer itself.
    buffer_var_map = {}
    for buf in root_block.alloc_buffers:
        buffer_var_map[buf.data] = buf

    read_write_regions = tir.analysis.get_block_read_write_region(block, buffer_var_map)
    access_regions = tir.analysis.get_block_access_region(block, buffer_var_map)

    # The read/write detector has to agree with the block annotations.
    tvm.ir.assert_structural_equal(block.reads, read_write_regions[0])
    tvm.ir.assert_structural_equal(block.writes, read_write_regions[1])

    # The two analyses are expected to disagree on both region lists.
    for index in (0, 1):
        with pytest.raises(ValueError):
            tvm.ir.assert_structural_equal(read_write_regions[index], access_regions[index])


def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
Expand Down Expand Up @@ -333,6 +361,7 @@ def test_access_of_decompose_reduction():
test_block_access_region_detector()
test_opaque_block()
test_opaque_access()
test_opaque_access_with_tvm_access_ptr()
test_match_buffer()
test_access_in_if_then_else_func()
test_access_in_branch_func()
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_tir_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent():
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)

# Test extent from input params
aptr = Ab.access_ptr("rw", extent=200)
assert tvm.ir.structural_equal(aptr.args[3], 200)
aptr = Ab.access_ptr("rw", offset=100, extent=100)
assert tvm.ir.structural_equal(aptr.args[3], 100)


def test_buffer_vload():
m = te.size_var("m")
Expand Down
97 changes: 62 additions & 35 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
)
)
T.evaluate(B.access_ptr("r", extent=128))
C[vi, vj] = B[vi, vj] + 1.0


Expand All @@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
)
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle"
)
)
T.evaluate(B.access_ptr("r", extent=128))
T.evaluate(C.access_ptr("w", extent=128))
C[vi, vj] = B[vi, vj] + 1.0


Expand Down Expand Up @@ -296,16 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access
T.reads(A[0:512])
T.writes(A_cache[0:512])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
)
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
)
)
T.evaluate(A.access_ptr("r", extent=512))
T.evaluate(A_cache.access_ptr("w", extent=512))
for i in range(512):
with T.block("BB"):
vi = T.axis.remap("S", [i])
Expand All @@ -325,16 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access should be kept
T.reads(A[0:512])
T.writes([A_cache[0:512]])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
)
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
)
)
T.evaluate(A.access_ptr("r", extent=512))
T.evaluate(A_cache.access_ptr("w", extent=512))
for i in T.serial(0, 512):
with T.block("B"):
vi = T.axis.spatial(512, i)
Expand Down Expand Up @@ -402,6 +374,51 @@ def inline_block_with_init(
)


@T.prim_func
def exp_exp_opaque_access_with_tvm_access_ptr(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    # Input for the compute-inline test: two chained exp blocks joined by the
    # intermediate buffer `compute_1`.
    compute_1 = T.alloc_buffer([16], dtype="float16")
    for i0 in T.serial(16):
        # Producer block: compute_1 = exp(x).
        with T.block("compute"):
            i0_1 = T.axis.spatial(16, i0)
            T.reads(x[i0_1])
            T.writes(compute_1[i0_1])
            compute_1[i0_1] = T.exp(x[i0_1], dtype="float16")
    for i0 in T.serial(16):
        # Consumer block: reads compute_1 and also passes a read-only
        # access pointer to `lookup_table` as an extra argument of T.exp.
        with T.block("compute_1"):
            i0_2 = T.axis.spatial(16, i0)
            T.reads(compute_1[i0_2], lookup_table[0:1024])
            T.writes(compute[i0_2])
            compute[i0_2] = T.exp(
                compute_1[i0_2],
                lookup_table.access_ptr("r"),
                dtype="float16",
            )


@T.prim_func
def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    # Expected result of inlining block "compute" into block "compute_1":
    # the intermediate buffer is gone and exp(x) appears inline.
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_1 = T.axis.spatial(16, i0)
            # The opaque access must not be added to the new write region when
            # it is wrapped in a tvm_access_ptr whose access mask is read-only.
            T.reads(x[i0_1], lookup_table[0:1024])
            T.writes(compute[i0_1])
            compute[i0_1] = T.exp(
                T.exp(x[i0_1], dtype="float16"),
                lookup_table.access_ptr("r"),
                dtype="float16",
            )


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -569,5 +586,15 @@ def test_inline_block_with_init():
sch.compute_inline(block=block)


def test_compute_inline_opaque_access_with_tvm_access_ptr():
    """Inline block "compute" and compare the module with the expected function.

    The schedule is built with ``debug_mask="all"``; after ``compute_inline``
    the scheduled ``"main"`` function must be structurally equal to
    ``exp_exp_opaque_access_with_tvm_access_ptr_inlined``.
    """
    schedule = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all")
    schedule.compute_inline(schedule.get_block("compute"))

    expected = exp_exp_opaque_access_with_tvm_access_ptr_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


# When run as a script, execute this file's tests through pytest, forwarding
# any extra command-line arguments, and exit with pytest's return code.
if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 96a274a

Please sign in to comment.