Skip to content

Commit

Permalink
#1592 [PASS] Fix missing mem CHECK in storage_rewrite (#1616)
Browse files Browse the repository at this point in the history
  • Loading branch information
xqdan authored and tqchen committed Aug 18, 2018
1 parent 9b0e499 commit 38d0835
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
6 changes: 6 additions & 0 deletions src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator {
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
}
}
Expand Down
63 changes: 42 additions & 21 deletions tests/python/unittest/test_pass_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,30 @@ def verify(n):
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1

def register_mem(scope_tb, max_bits):
#Register mem
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=max_bits,
head_address=None)

def test_alloc_seq():
scope_tb = "local.L0A"
max_bits = 1024 * 1024 * 1024

register_mem(scope_tb, max_bits)

ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="B", scope="local.L0A")
A = ib.allocate("float32", 200, name="B", scope=scope_tb)
A[j] = 1.3

body = ib.get()
Expand Down Expand Up @@ -233,16 +248,9 @@ def test_parallel_alloc():

assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))

def test_inplace_rule2():
def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
#Test Buffer
scope_tb = "local_TB2"
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
register_mem(scope_tb, max_bits)
m = 10
A = tvm.placeholder((m,), name='A')
C = tvm.placeholder((m,), name='C')
Expand Down Expand Up @@ -275,16 +283,23 @@ def verify(n):
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2

def test_exceed_mem():
max_bits = 639
# The critical max_num_bits is between 639 and 640
loc = -1
try:
test_inplace_rule2("local_TEM", max_bits)
except Exception as e:
estr = str(e)
loc = estr.find('Allocation exceed bound of memory')
assert loc != -1

def test_inplace_rule3():
#Test Buffer
scope_tb = "local_TB3"
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
max_bits=1024 * 1024 * 1024

register_mem(scope_tb, max_bits)
m = 10
B0 = tvm.placeholder((m,), name='B0')
B1 = tvm.placeholder((m,), name='B1')
Expand Down Expand Up @@ -388,17 +403,22 @@ def verify(n):
assert num_alloc[0] == 1

def test_alloc_seq_type2():
scope_tb = "local.L0A2"
max_bits=1024 * 1024 * 1024

register_mem(scope_tb, max_bits)

ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 20, name="j") as j:
B = ib.allocate("int16", 400, name="B", scope="local.L0A")
B = ib.allocate("int16", 400, name="B", scope=scope_tb)
B[j] = tvm.const(1, "int16")
with ib.for_range(0, 10, name="j") as j:
C = ib.allocate("float32", 200, name="C", scope="local.L0A")
C = ib.allocate("float32", 200, name="C", scope=scope_tb)
C[j] = 1.2

body = ib.get()
Expand Down Expand Up @@ -465,6 +485,7 @@ def test_replace_dataflow():
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
test_exceed_mem()
test_inplace_rule3()
test_alloc_seq_type()
test_alloc_seq_type2()
Expand Down

0 comments on commit 38d0835

Please sign in to comment.