Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1592 [PASS] Fix missing mem CHECK in storage_rewrite #1616

Merged
merged 7 commits into from
Aug 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator {
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
}
}
Expand Down
63 changes: 42 additions & 21 deletions tests/python/unittest/test_pass_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,30 @@ def verify(n):
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1

def register_mem(scope_tb, max_bits):
    """Register a MemoryInfo provider for the memory scope ``scope_tb``.

    The provider is registered through ``tvm.register_func`` under the name
    ``tvm.info.mem.<scope_tb>`` and returns a ``MemoryInfo`` node whose
    ``max_num_bits`` is ``max_bits``.

    Parameters
    ----------
    scope_tb : str
        Memory scope name appended to the ``tvm.info.mem.`` prefix.
    max_bits : int
        Value reported as ``max_num_bits`` for the scope.
    """
    func_name = "tvm.info.mem.%s" % scope_tb

    def mem_info_inp_buffer():
        # Only max_num_bits depends on the caller; the other fields are fixed.
        return tvm.make.node(
            "MemoryInfo",
            unit_bits=16,
            max_simd_bits=32,
            max_num_bits=max_bits,
            head_address=None,
        )

    # Same effect as decorating mem_info_inp_buffer with
    # @tvm.register_func(func_name).
    tvm.register_func(func_name)(mem_info_inp_buffer)

def test_alloc_seq():
scope_tb = "local.L0A"
max_bits = 1024 * 1024 * 1024

register_mem(scope_tb, max_bits)

ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="B", scope="local.L0A")
A = ib.allocate("float32", 200, name="B", scope=scope_tb)
A[j] = 1.3

body = ib.get()
Expand Down Expand Up @@ -233,16 +248,9 @@ def test_parallel_alloc():

assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))

def test_inplace_rule2():
def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
#Test Buffer
scope_tb = "local_TB2"
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
register_mem(scope_tb, max_bits)
m = 10
A = tvm.placeholder((m,), name='A')
C = tvm.placeholder((m,), name='C')
Expand Down Expand Up @@ -275,16 +283,23 @@ def verify(n):
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2

def test_exceed_mem():
    """Check that an undersized memory scope triggers the bound-exceeded error.

    Runs ``test_inplace_rule2`` against the scope ``local_TEM`` registered
    with too small a ``max_num_bits`` and requires that it raises an
    exception whose text contains the allocation-bound message.
    """
    # The critical max_num_bits is between 639 and 640
    max_bits = 639
    expected = 'Allocation exceed bound of memory'
    message = None
    try:
        test_inplace_rule2("local_TEM", max_bits)
    except Exception as e:
        message = str(e)
    # Fails both when nothing was raised and when the message is different.
    assert message is not None and expected in message

def test_inplace_rule3():
#Test Buffer
scope_tb = "local_TB3"
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
max_bits=1024 * 1024 * 1024

register_mem(scope_tb, max_bits)
m = 10
B0 = tvm.placeholder((m,), name='B0')
B1 = tvm.placeholder((m,), name='B1')
Expand Down Expand Up @@ -388,17 +403,22 @@ def verify(n):
assert num_alloc[0] == 1

def test_alloc_seq_type2():
scope_tb = "local.L0A2"
max_bits=1024 * 1024 * 1024

register_mem(scope_tb, max_bits)

ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2
with ib.for_range(0, 20, name="j") as j:
B = ib.allocate("int16", 400, name="B", scope="local.L0A")
B = ib.allocate("int16", 400, name="B", scope=scope_tb)
B[j] = tvm.const(1, "int16")
with ib.for_range(0, 10, name="j") as j:
C = ib.allocate("float32", 200, name="C", scope="local.L0A")
C = ib.allocate("float32", 200, name="C", scope=scope_tb)
C[j] = 1.2

body = ib.get()
Expand Down Expand Up @@ -465,6 +485,7 @@ def test_replace_dataflow():
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
test_exceed_mem()
test_inplace_rule3()
test_alloc_seq_type()
test_alloc_seq_type2()
Expand Down