diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 363b7c4cf7cce..70908eb43d6b3 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor { // record the number of threads in a block std::string name = var.get()->name_hint; if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") { + size_t length = static_cast(extent->value); if (!visited_threads_.count(name)) { visited_threads_.insert(name); - size_t length = static_cast(extent->value); thread_per_block_ *= length; if (name == "threadIdx.x") { valid_ &= length <= max_thread_x_; + thread_x_extent_ = length; } else if (name == "threadIdx.y") { valid_ &= length <= max_thread_y_; + thread_y_extent_ = length; } else if (name == "threadIdx.z") { valid_ &= length <= max_thread_z_; + thread_z_extent_ = length; + } + } else { + // the thread should be bound to axes with the same length + if (name == "threadIdx.x") { + valid_ &= length == thread_x_extent_; + } else if (name == "threadIdx.y") { + valid_ &= length == thread_y_extent_; + } else if (name == "threadIdx.z") { + valid_ &= length == thread_z_extent_; } } } @@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor { std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; + size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; + size_t local_memory_per_block_; size_t shared_memory_per_block_; size_t thread_per_block_; diff --git a/tests/python/unittest/test_pass_verify_gpu_code.py b/tests/python/unittest/test_pass_verify_gpu_code.py index 6fc0387cf1446..e3884a727852e 100644 --- a/tests/python/unittest/test_pass_verify_gpu_code.py +++ b/tests/python/unittest/test_pass_verify_gpu_code.py @@ -162,8 +162,32 @@ def test_multiple_kernels(): tvm.build(s, [A, C], target) assert valid[0] +def test_wrong_bind(): + N = 1024 + + A = tvm.placeholder((N, N-1), name='A') + B = tvm.compute((N, N-1), lambda i, j: A[i, j]) + + s = tvm.create_schedule([B.op]) + + # bind a thread axis to two loop axes with different lengths + s[B].bind(s[B].op.axis[0], tvm.thread_axis("threadIdx.x")) + s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x")) + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, max_threads_per_block=N*N))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + if __name__ == "__main__": test_local_memory() test_shared_memory() test_num_thread() test_multiple_kernels() + test_wrong_bind()