Skip to content

Commit

Permalink
[PASS] Enhance gpu verify pass (apache#1660)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and tqchen committed Aug 30, 2018
1 parent 4e6740a commit 8d3d4c4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/pass/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor {
// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
size_t length = static_cast<size_t>(extent->value);
thread_per_block_ *= length;

if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
thread_z_extent_ = length;
}
} else {
// the thread should be bound to axes with the same length
if (name == "threadIdx.x") {
valid_ &= length == thread_x_extent_;
} else if (name == "threadIdx.y") {
valid_ &= length == thread_y_extent_;
} else if (name == "threadIdx.z") {
valid_ &= length == thread_z_extent_;
}
}
}
Expand All @@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor {
std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;

size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;

size_t local_memory_per_block_;
size_t shared_memory_per_block_;
size_t thread_per_block_;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_pass_verify_gpu_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,32 @@ def test_multiple_kernels():
tvm.build(s, [A, C], target)
assert valid[0]

def test_wrong_bind():
N = 1024

A = tvm.placeholder((N, N-1), name='A')
B = tvm.compute((N, N-1), lambda i, j: A[i, j])

s = tvm.create_schedule([B.op])

# bind a thread axis to two loop axes with different lengths
s[B].bind(s[B].op.axis[0], tvm.thread_axis("threadIdx.x"))
s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))

for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue

valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, max_threads_per_block=N*N))]}):
tvm.build(s, [A, B], target)
assert not valid[0]


if __name__ == "__main__":
test_local_memory()
test_shared_memory()
test_num_thread()
test_multiple_kernels()
test_wrong_bind()

0 comments on commit 8d3d4c4

Please sign in to comment.