Check tagged-memory allocation size in StoragePlanRewriter and add tests.

The C++ hunk adds a CHECK_LE right after e->new_alloc is built: when the
storage scope has a non-empty tag, the allocation's bit count (const_nbits
rounded down to a whole number of elements) must not exceed max_num_bits of
the MemoryInfo returned by GetMemoryInfo for that scope.

The Python hunks factor the repeated MemoryInfo registration into
register_mem(), parameterize test_inplace_rule2 over scope and max_bits, and
add test_exceed_mem(), which expects the "Allocation exceed bound of memory"
message when max_num_bits is 639.

NOTE(review): leading indentation, blank lines and the wrapping of the
Allocate::make(...) context lines were reconstructed from the hunk line
counts and language conventions -- confirm against the target tree before
applying.

diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 0170499e1491..877216ed7656 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator {
           e->new_alloc = Allocate::make(
               e->alloc_var, alloc_type, {combo_size}, const_true(),
               Evaluate::make(0));
+          if (e->scope.tag.length() != 0) {
+            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
+            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
+            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
+                << "Allocation exceed bound of memory tag " << e->scope.to_string();
+          }
         }
       }
     }
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index 2bb02998982f..3c07a1f26aff 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -28,15 +28,30 @@ def verify(n):
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 1
 
+def register_mem(scope_tb, max_bits):
+    #Register mem
+    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
+    def mem_info_inp_buffer():
+        return tvm.make.node("MemoryInfo",
+                        unit_bits= 16,
+                        max_simd_bits=32,
+                        max_num_bits=max_bits,
+                        head_address=None)
+
 def test_alloc_seq():
+    scope_tb = "local.L0A"
+    max_bits = 1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
+
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     with ib.for_range(0, n, name="i") as i:
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="A", scope=scope_tb)
             A[j] = 1.2
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="B", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="B", scope=scope_tb)
             A[j] = 1.3
 
     body = ib.get()
@@ -233,16 +248,9 @@ def test_parallel_alloc():
     assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
 
 
-def test_inplace_rule2():
+def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     #Test Buffer
-    scope_tb = "local_TB2"
-    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
-    def mem_info_inp_buffer():
-        return tvm.make.node("MemoryInfo",
-                        unit_bits= 16,
-                        max_simd_bits=32,
-                        max_num_bits=1024*1024*1024,
-                        head_address=None)
+    register_mem(scope_tb, max_bits)
     m = 10
     A = tvm.placeholder((m,), name='A')
     C = tvm.placeholder((m,), name='C')
@@ -275,16 +283,23 @@ def verify(n):
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 2
 
+def test_exceed_mem():
+    max_bits = 639
+    # The critical max_num_bits is between 639 and 640
+    loc = -1
+    try:
+        test_inplace_rule2("local_TEM", max_bits)
+    except Exception as e:
+        estr = str(e)
+        loc = estr.find('Allocation exceed bound of memory')
+    assert loc != -1
+
 def test_inplace_rule3():
     #Test Buffer
     scope_tb = "local_TB3"
-    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
-    def mem_info_inp_buffer():
-        return tvm.make.node("MemoryInfo",
-                        unit_bits= 16,
-                        max_simd_bits=32,
-                        max_num_bits=1024*1024*1024,
-                        head_address=None)
+    max_bits=1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
     m = 10
     B0 = tvm.placeholder((m,), name='B0')
     B1 = tvm.placeholder((m,), name='B1')
@@ -388,17 +403,22 @@ def verify(n):
     assert num_alloc[0] == 1
 
 def test_alloc_seq_type2():
+    scope_tb = "local.L0A2"
+    max_bits=1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
+
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     with ib.for_range(0, n, name="i") as i:
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="A", scope=scope_tb)
             A[j] = 1.2
         with ib.for_range(0, 20, name="j") as j:
-            B = ib.allocate("int16", 400, name="B", scope="local.L0A")
+            B = ib.allocate("int16", 400, name="B", scope=scope_tb)
             B[j] = tvm.const(1, "int16")
         with ib.for_range(0, 10, name="j") as j:
-            C = ib.allocate("float32", 200, name="C", scope="local.L0A")
+            C = ib.allocate("float32", 200, name="C", scope=scope_tb)
             C[j] = 1.2
 
     body = ib.get()
@@ -465,6 +485,7 @@ def test_replace_dataflow():
     test_storage_combine()
     test_storage_share_gpu()
     test_inplace_rule2()
+    test_exceed_mem()
     test_inplace_rule3()
     test_alloc_seq_type()
     test_alloc_seq_type2()