Skip to content

Commit

Permalink
[TIR]Add parameter extent for access_ptr.
Browse files Browse the repository at this point in the history
  • Loading branch information
sqing committed Apr 26, 2022
1 parent 42dc437 commit 2998158
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 38 deletions.
5 changes: 3 additions & 2 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ class Buffer : public ObjectRef {
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
* \param extent The extent of ptr. If left undefined, the extent computed from the buffer and the offset is used.
*/
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
PrimExpr extent = PrimExpr(ObjectPtr<Object>(nullptr))) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Buffer(Object):
READ = 1
WRITE = 2

def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
Expand All @@ -66,6 +66,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
extent: Expr, optional
The extent of the pointer. If None, the extent computed from
the buffer and the offset is used.
Examples
--------
.. code-block:: python
Expand All @@ -78,6 +81,8 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
# Get access ptr for read with extent
buffer.access_ptr("r", extent = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
Expand All @@ -90,8 +95,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
extent = convert(extent)
return _ffi_api.BufferAccessPtr(
self, access_mask, ptr_type, content_lanes, offset # type: ignore
self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
)

def vload(self, begin, dtype=None):
Expand Down
8 changes: 6 additions & 2 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return slice;
}

PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
PrimExpr offset) const {
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
PrimExpr input_extent) const {
const BufferNode* self = operator->();
ICHECK(self != nullptr);
PrimExpr e_dtype;
Expand All @@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
} else {
e_dtype = tir::TypeAnnotation(self->dtype);
}

if (!input_extent.same_as(PrimExpr(ObjectPtr<Object>(nullptr)))) {
extent = input_extent;
}
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
make_const(DataType::Int(32), access_mask)};
return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_tir_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent():
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)

# Test extent from input params
aptr = Ab.access_ptr("rw", extent=200)
assert tvm.ir.structural_equal(aptr.args[3], 200)
aptr = Ab.access_ptr("rw", offset=100, extent=100)
assert tvm.ir.structural_equal(aptr.args[3], 100)


def test_buffer_vload():
m = te.size_var("m")
Expand Down
39 changes: 7 additions & 32 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), B.data, 0, 128, 1, dtype="handle"
)
)
T.evaluate(B.access_ptr("r", extent=128))
C[vi, vj] = B[vi, vj] + 1.0


Expand All @@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), B.data, 0, 128, 1, dtype="handle"
)
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), C.data, 0, 128, 2, dtype="handle"
)
)
T.evaluate(B.access_ptr("r", extent=128))
T.evaluate(C.access_ptr("w", extent=128))
C[vi, vj] = B[vi, vj] + 1.0


Expand Down Expand Up @@ -296,14 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access
T.reads(A[0:512])
T.writes(A_cache[0:512])
T.evaluate(
T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, 1, dtype="handle")
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, 2, dtype="handle"
)
)
T.evaluate(A.access_ptr("r", extent=512))
T.evaluate(A_cache.access_ptr("w", extent=512))
for i in range(512):
with T.block("BB"):
vi = T.axis.remap("S", [i])
Expand All @@ -323,14 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access should be kept
T.reads(A[0:512])
T.writes([A_cache[0:512]])
T.evaluate(
T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, 1, dtype="handle")
)
T.evaluate(
T.tvm_access_ptr(
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, 2, dtype="handle"
)
)
T.evaluate(A.access_ptr("r", extent=512))
T.evaluate(A_cache.access_ptr("w", extent=512))
for i in T.serial(0, 512):
with T.block("B"):
vi = T.axis.spatial(512, i)
Expand Down Expand Up @@ -614,7 +590,6 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr():
sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all")
compute = sch.get_block("compute")
sch.compute_inline(compute)
print(sch.mod.script())
tvm.ir.assert_structural_equal(
exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"]
)
Expand Down

0 comments on commit 2998158

Please sign in to comment.