diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index 2f8fbe0ea6e7..b48749c4c77c 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -188,25 +188,27 @@ class MatchBufferLower : public StmtExprMutator {
       Load load = Downcast<Load>(source_buffer.vload(indices, source_buffer->dtype));
       Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset");
       CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0))
-          << "The source elem_offset " << buffer->elem_offset
-          << " does not satisfy the offset_factor " << buffer->offset_factor << ".";
+          << "The source elem_offset " << load->index << " does not satisfy the offset_factor "
+          << buffer->offset_factor << ".";
     }
 
     // Step 2.3. Check and update strides
+    // The matched buffer may have lower rank than the source region: it lines up with the
+    // trailing source dimensions, so the leading `offset` dimensions are skipped when indexing.
+    ICHECK(source->region.size() >= buffer->shape.size());
+    size_t offset = source->region.size() - buffer->shape.size();
     // Check if target buffer strides are defined
     if (!buffer->strides.empty()) {
       ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
       PrimExpr stride = make_const(DataType::Int(32), 1);
       for (size_t i = buffer->shape.size(); i > 0; --i) {
-        const PrimExpr& shape = source_buffer->shape[i - 1];
+        const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
         Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
         stride *= shape;
       }
     }
 
     // Step 2.4. Check and update shape
-    ICHECK(source->region.size() >= buffer->shape.size());
-    size_t offset = source->region.size() - buffer->shape.size();
     for (size_t i = 0; i < buffer->shape.size(); ++i) {
       const Range& range = source->region[i + offset];
       Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i));
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 78a8c5117849..efb2073e0862 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -156,6 +156,54 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
             )
 
 
+@tvm.script.tir
+def high_dim_opaque_access(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 32, 64))
+    for i, j, k in tir.grid(16, 2, 4):
+        with tir.block([]):
+            As_0 = tir.var("int32")
+            As_1 = tir.var("int32")
+            tir.reads([])
+            tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+            sub_A = tir.match_buffer(
+                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
+                (16, 16),
+                strides=[As_0, As_1],
+                offset_factor=1,
+            )
+            tir.evaluate(
+                tir.intrin_test(
+                    sub_A.data,
+                    sub_A.elem_offset,
+                    sub_A.strides[0],
+                    sub_A.strides[1],
+                    sub_A.shape[0],
+                    sub_A.shape[1],
+                    dtype="handle",
+                )
+            )
+
+
+@tvm.script.tir
+def transformed_high_dim_opaque_access(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 32, 64))
+    for i, j, k in tir.grid(16, 2, 4):
+        with tir.block([]):
+            tir.reads([])
+            tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+            tir.evaluate(
+                tir.intrin_test(
+                    A.data,
+                    i * 2048 + j * 1024 + k * 16,
+                    64,
+                    1,
+                    16,
+                    16,
+                    dtype="handle",
+                )
+            )
+
+
 @tvm.script.tir
 def recursive_match(a: ty.handle, b: ty.handle) -> None:
     A = tir.match_buffer(a, (64, 64, 64))
@@ -419,6 +467,10 @@ def test_opaque_access():
     _check(opaque_access, transformed_opaque_access)
 
 
+def test_high_dim_opaque_access():
+    _check(high_dim_opaque_access, transformed_high_dim_opaque_access)
+
+
 def test_recursive_match():
     _check(recursive_match, transformed_recursive_match)
 
@@ -447,6 +499,7 @@ def test_fail_match_func_param():
 if __name__ == "__main__":
     test_buffer_load_store()
     test_opaque_access()
+    test_high_dim_opaque_access()
     test_recursive_match()
     test_symbolic_match()
     test_rank0_buffer()