Skip to content

Commit

Permalink
[TIR][LowerMatchBuffer] Fix lowering strides when source buffer has n…
Browse files Browse the repository at this point in the history
…on-empty strides (#9166)
  • Loading branch information
vinx13 authored Oct 1, 2021
1 parent 61fbda9 commit 62a7fb7
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/tir/transforms/lower_match_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,19 @@ class MatchBufferLower : public StmtExprMutator {
int offset = source->region.size() - buffer->shape.size();
if (!buffer->strides.empty()) {
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
PrimExpr stride = make_const(DataType::Int(32), 1);
for (size_t i = buffer->shape.size(); i > 0; --i) {
const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
stride *= shape;
if (source_buffer->strides.empty()) {
PrimExpr stride = make_const(DataType::Int(32), 1);
for (size_t i = buffer->shape.size(); i > 0; --i) {
const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
stride *= shape;
}
} else {
ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size());
for (size_t i = buffer->shape.size(); i > 0; --i) {
const PrimExpr& stride = source_buffer->strides[i - 1 + offset];
Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
}
}
}

Expand Down
52 changes: 52 additions & 0 deletions tests/python/unittest/test_tir_lower_match_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,54 @@ def transformed_high_dim_opaque_access(a: ty.handle) -> None:
)


@tvm.script.tir
def high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
for i, j, k in tir.grid(16, 2, 4):
with tir.block([]):
As_0 = tir.var("int32")
As_1 = tir.var("int32")
tir.reads([])
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
sub_A = tir.match_buffer(
A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
(16, 16),
strides=[As_0, As_1],
offset_factor=1,
)
tir.evaluate(
tir.intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
sub_A.strides[1],
sub_A.shape[0],
sub_A.shape[1],
dtype="handle",
)
)


@tvm.script.tir
def transformed_high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
for i, j, k in tir.grid(16, 2, 4):
with tir.block([]):
tir.reads([])
tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
tir.evaluate(
tir.intrin_test(
A.data,
i * 2576 + j * 1280 + k * 16,
80,
1,
16,
16,
dtype="handle",
)
)


@tvm.script.tir
def recursive_match(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (64, 64, 64))
Expand Down Expand Up @@ -469,6 +517,10 @@ def test_opaque_access():

def test_high_dim_opaque_access():
_check(high_dim_opaque_access, transformed_high_dim_opaque_access)
_check(
high_dim_opaque_access_with_source_strides,
transformed_high_dim_opaque_access_with_source_strides,
)


def test_recursive_match():
Expand Down

0 comments on commit 62a7fb7

Please sign in to comment.