diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index ae3e9d885f1a2..44c92b792f122 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -57,7 +57,7 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # match_buffers of the block, # which bind a sub-region of source buffer into a new buffer - D = tir.match_buffer_region(C[vi, vj]) + D = tir.match_buffer(C[vi, vj], ()) # init part of the block, executed when all reduce axes are the beginning value with tir.init(): @@ -65,13 +65,13 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # block body CC[0, 0] = A[vi, vk] * B[vj, vk] - D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] + D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] """ alloc_buffers: List[Buffer] = [] """List[Buffer]: list of tir.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature""" + """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature""" iter_bindings: Mapping[Var, PrimExpr] = {} """Mapping[Var, PrimExpr]: map of block iter var to its values""" reads: Optional[List[BufferSlice]] = None diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index dbd910dc790be..3e47eb5a4254b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -652,7 +652,7 @@ def LowerMatchBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerMatchBuffer() + return _ffi_api.LowerMatchBuffer() # type: ignore def FlattenBuffer(): diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0f3b89932b683..f0b82291d35ae 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -598,8 +598,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { - body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source) - << ")" << Doc::NewLine(); + body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" + << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 8f39de4c96a42..e680d689735d6 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -86,23 +86,16 @@ class LCADetector : public StmtExprVisitor { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); + auto* current_scope = arena_.make(parent_scope, op, n); + + ancestor_scopes_.push_back(current_scope); // Update match_buffers for (const MatchBufferRegion& match_buffer : op->match_buffers) { - const Buffer& target_buffer = match_buffer->buffer; - buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get()); - - const Buffer& source_buffer = match_buffer->source->buffer; - auto it = match_buffers_.find(source_buffer.get()); - if (it != match_buffers_.end()) { - match_buffers_[target_buffer.get()] = it->second; - } else { - match_buffers_[target_buffer.get()] = source_buffer.get(); - } + UpdateBufferLCA(match_buffer->source->buffer.get()); + match_buffers_.insert(match_buffer->buffer.get()); } - const ScopeInfo* parent_scope = ancestor_scopes_.back(); - auto* current_scope = arena_.make(parent_scope, op, n); - ancestor_scopes_.push_back(current_scope); StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); } @@ -144,12 +137,11 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { - auto it = match_buffers_.find(buffer); - if (it != match_buffers_.end()) { - buffer = it->second; + if (match_buffers_.find(buffer) == match_buffers_.end()) { + // Ingore buffer created by block match_buffer + const ScopeInfo*& lca = buffer_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); } - const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { @@ -184,7 +176,7 @@ class LCADetector : public StmtExprVisitor { /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; /*! \brief The match buffers inside blocks. */ - std::unordered_map match_buffers_ = {}; + std::unordered_set match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c15b3bb47bf4d..f265a8ae2b1b9 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,15 +36,12 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map, bool contain_root) - : buffer_var_map_(buffer_var_map), contain_root_(contain_root) {} + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Whether the stmt contains at least one block. */ bool contains_block = false; private: Map* buffer_var_map_; - bool contain_root_; - bool visited_root_ = false; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; Stmt body = StmtMutator::VisitStmt_(op); @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator { } Stmt VisitStmt_(const BlockNode* op) override { - bool is_root_block = contain_root_ && !visited_root_; - visited_root_ = true; // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->Set(target_buffer->data, target_buffer); + } Block block = Downcast(StmtMutator::VisitStmt_(op)); // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->erase(target_buffer->data); + } // Get access detection mask // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write int mask = 0; @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - if (op->iter_vars.empty()) { - // non-root opaque block is not allowed - CHECK(is_root_block) - << "ValueError: Can not auto detect buffer access region for an opaque block. Please " - "annotate the access region manually."; - return std::move(block); - } auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); const Array& reads = access_region[0]; const Array& writes = access_region[1]; @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } bool contain_root = root_allocates.empty() && func->body->IsInstance() && Downcast(func->body)->block->iter_vars.empty(); - ScriptCompleter script_completer(&buffer_var_map, contain_root); + ScriptCompleter script_completer(&buffer_var_map); // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d72fd8f72d6d6..ef59213d3610a 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -697,9 +697,9 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { const Buffer& source_buffer = source->buffer; arith::Analyzer analyzer; // Check scope and dtype - CHECK_EQ(buffer->scope, source_buffer->scope) - << "MatchBuffer " << buffer << " scope mismatch:" << buffer->scope << " vs. " - << source_buffer->scope; + CHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " + << source_buffer.scope(); CHECK_EQ(buffer->dtype, source_buffer->dtype) << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " << source_buffer->dtype; @@ -798,7 +798,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->buffer->name << " = match_buffer_region("; + p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 9ad115c647817..2d4f6a8d55a33 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -135,9 +135,9 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& source_buffer = source->buffer; // Step.1.1. Check scope & dtype - ICHECK_EQ(buffer->scope, source_buffer->scope) - << "MatchBuffer " << buffer << " scope mismatch:" << buffer->scope << "vs." - << source_buffer->scope; + ICHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs." + << source_buffer.scope(); ICHECK_EQ(buffer->dtype, source_buffer->dtype) << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs." << source_buffer->dtype; @@ -251,4 +251,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB } // namespace transform } // namespace tir -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 9053b35348f18..3fa4795870d59 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -302,6 +302,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: ) +@tvm.testing.requires_cuda def test_gemm_tensorcore(): dev = tvm.device("cuda", 0) a_np = np.random.uniform(size=(1024, 1024)).astype("float16") @@ -310,7 +311,6 @@ def test_gemm_tensorcore(): a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) - print(tvm.script.asscript(tvm.lower(tensorcore_gemm, simple_mode=True))) f = tvm.build(tensorcore_gemm, target="cuda", name="dense") f(a, b, c) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) @@ -324,4 +324,4 @@ def test_gemm_tensorcore(): if __name__ == "__main__": - test_gemm_tensorcore() \ No newline at end of file + test_gemm_tensorcore() diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index c7cf2f6edfb17..8c2b2710f1ba5 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -77,12 +77,11 @@ def match_buffer_func(a: ty.handle, b: ty.handle) -> None: with tir.block([8, 8], "block") as [vi, vj]: tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) with tir.block([16, 16], "AAA") as [i, j]: - AAA = tir.match_buffer(AA[i, j], ()) - AAA[()] = 1.0 + AA = tir.match_buffer(A[i, j], ()) + AA[()] = 1.0 tir.evaluate(B0.data) tir.evaluate(B1.data) diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 70de4372805fb..7641f0ac46cbc 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -62,6 +62,25 @@ def match_buffer_func() -> None: tir.evaluate(B1.data) +@tvm.script.tir +def opaque_block_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((16, 16), "float32") + B = tir.alloc_buffer((16, 16), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -76,6 +95,21 @@ def test_block_access_region_detector(): ) +def test_opaque_block(): + alloc_buffers = opaque_block_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + block0 = opaque_block_func.body.block.body.body.block + ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + tvm.ir.assert_structural_equal(block0.reads, ret[0]) + tvm.ir.assert_structural_equal(block0.writes, ret[1]) + + block1 = block0.body.body.block + ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + tvm.ir.assert_structural_equal(block1.reads, ret[0]) + tvm.ir.assert_structural_equal(block1.writes, ret[1]) + + def test_match_buffer(): root_block = match_buffer_func.body.block block = root_block.body.body.body.block @@ -97,4 +131,5 @@ def test_match_buffer(): if __name__ == "__main__": test_block_access_region_detector() + test_opaque_block() test_match_buffer() diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 0c297820a850b..a0f188ed867ba 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -65,8 +65,8 @@ def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None: A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] -@tvm.ir.register_op_attr("tir.test_intrin", "") -def test_intrin(data, elem_offset, stride_0, stride_1, shape_0, shape_1): +@tvm.ir.register_op_attr("tir.intrin_test", "") +def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): return 0 @@ -85,7 +85,7 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -108,7 +108,7 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -129,7 +129,7 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: tir.reads([]) tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) tir.evaluate( - tir.test_intrin( + tir.intrin_test( A.data, i * 131072 + j * 128 + k * 16, 8192, @@ -144,7 +144,7 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: tir.reads([]) tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * 4096 + j * 2048 + k * 8, 64, @@ -205,7 +205,7 @@ def recursive_match(a: ty.handle, b: ty.handle) -> None: offset_factor=1, ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_sub_A.data, sub_sub_A.elem_offset, sub_sub_A.strides[0], @@ -250,7 +250,7 @@ def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: ] ) tir.evaluate( - tir.test_intrin( + tir.intrin_test( A.data, i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, 64, @@ -282,7 +282,7 @@ def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None sub_A[ii, jj] = 1 for j in range(0, 4): tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -306,7 +306,7 @@ def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.in A[i * m + ii, jj] = 1 for j in range(0, 4): tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * n * (m * 4), m * 4, @@ -330,7 +330,7 @@ def rank0_buffer(a: ty.handle, b: ty.handle) -> None: sub_B = tir.match_buffer(B[i, j], (), offset_factor=1) sub_A[()] = 1 tir.evaluate( - tir.test_intrin( + tir.intrin_test( sub_B.data, sub_B.elem_offset, 0, @@ -352,7 +352,7 @@ def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: tir.writes([A[i, j], B[i, j]]) A[i, j] = 1 tir.evaluate( - tir.test_intrin( + tir.intrin_test( B.data, i * 8 + j, 0, diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 07a82ba9936c7..dbae0b6fa516d 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -398,7 +398,7 @@ def test_block_blockrealize(): ) ] writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] - match_buffer_region = tvm.tir.MatchBufferRegion( + block_match_buffer = tvm.tir.MatchBufferRegion( match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) ) @@ -410,7 +410,7 @@ def test_block_blockrealize(): body, init=init_body, alloc_buffers=[alloc_buffer], - match_buffers=[match_buffer_region], + match_buffers=[block_match_buffer], annotations={"attr_key": "attr_value"}, ) @@ -462,7 +462,7 @@ def test_block_blockrealize(): assert output.find("reads") != -1 assert output.find("writes") != -1 assert output.find("alloc_buffer") != -1 - assert output.find("match_buffer_region") != -1 + assert output.find("match_buffer") != -1 assert output.find("attr") != -1 assert output.find("with init()") != -1 @@ -471,7 +471,6 @@ def test_block_blockrealize(): test_intimm_cond() test_buffer_load_store() test_vars() - test_scoped_storage_var() test_prim_func() test_cast() test_attr() diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index c34ec8d610d61..0a33db09aef1c 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -171,7 +171,7 @@ def buffer_matched(a: ty.handle, c: ty.handle) -> None: with tir.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 with tir.block([128, 128], "C") as [vi, vj]: - Bb = tir.match_buffer_region(B[vi : vi + 1, vj]) + Bb = tir.match_buffer(B[vi : vi + 1, vj], (1, 1)) C[vi, vj] = Bb[0, 0] + 1.0 diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 7c06b5ef5ca12..a469c6d0cc131 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -293,6 +293,52 @@ def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: C[i, j] = B[0, j] +@tvm.script.tir +def match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((16, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[i, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[i, j], ()) + C1[()] = B2[()] * 2.0 + + +@tvm.script.tir +def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((1, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[0, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[0, j], ()) + C1[()] = B2[()] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -321,6 +367,10 @@ def test_complex(): _check(complex_func, compacted_complex_func) +def test_match_buffer(): + _check(match_buffer_func, compacted_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -329,3 +379,4 @@ def test_complex(): test_warp_mem() test_symbolic() test_complex() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 3fb8331d39fca..badf5e0e4d10d 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -18,6 +18,8 @@ from tvm import tir from tvm.script import ty +# pylint: disable=no-self-argument + @tvm.script.tir class WithInit: @@ -43,11 +45,46 @@ def main(a: ty.handle, b: ty.handle) -> None: B[i] += A[i, j, k] +@tvm.script.tir +class InitWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with tir.init(): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + +@tvm.script.tir +class BranchWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + def test_lower_reduction(): origin_mod = WithInit() mod = tvm.tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch(), True) +def test_lower_match_buffer(): + origin_mod = InitWithMatchBuffer() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + + if __name__ == "__main__": test_lower_reduction() + test_lower_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index d42c5e1f8626d..022c964df0c78 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -115,6 +115,28 @@ def transformed_func() -> None: ) +@tvm.script.tir +def match_buffer_func() -> None: + C = tir.alloc_buffer((128, 128)) + with tir.block([128]) as [vi]: + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def transformed_match_buffer_func() -> None: + for i in range(0, 128): + with tir.block([128]) as [vi]: + tir.bind(vi, i) + C = tir.alloc_buffer((128, 128)) + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -123,6 +145,11 @@ def test_locate_buffer_allocation(): _check(original_func, transformed_func) +def test_match_buffer_allocation(): + _check(match_buffer_func, transformed_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() + test_match_buffer_allocation() diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index a4d2dec0cce99..4798e9e098655 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -177,19 +177,6 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -def test_complete_opaque_block_error(): - def render(e): - pass - - override_renderer(render) - - try: - from_source(func_with_opaque_block) - except tvm.error.DiagnosticError: - return - assert False - - @tvm.script.tir def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: data_buf = tir.match_buffer(data, (16, 16), "float32") @@ -255,10 +242,46 @@ def test_complete_buffer_indices(): tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) +@tvm.script.tir +def match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +@tvm.script.tir +def expected_match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, 0:16]) + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + tir.reads([]) + tir.writes(A0[0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads([]) + tir.writes(A0[j]) + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +def test_complete_match_buffer(): + tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() - test_complete_opaque_block_error() test_complete_part_region() test_complete_buffer_indices() + test_complete_match_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 71e0d1ba52e90..7aeceeccfa893 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -202,7 +202,7 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer_region(vi) # error + A = tir.match_buffer(vi) # error tir.evaluate(1.0) @@ -431,4 +431,4 @@ def render(e): test_error_index_with_stop_slice() test_mismatch_args() test_tvm_exception_catch() - test_match_buffer_shape_mismatch() \ No newline at end of file + test_match_buffer_shape_mismatch()