From ccb200564eb9b8d63e9ec4544b61f0aeff2d28a3 Mon Sep 17 00:00:00 2001
From: Shizhi Tang
Date: Mon, 18 May 2020 10:55:05 +0800
Subject: [PATCH] [CUDA] Fix codegen for warp shuffle intrinsics (#5606)

* fix shfl intrin

* improve test_lower_warp_memory_cuda_half_a_warp
---
 src/target/source/intrin_rule_cuda.cc               |  2 +-
 .../test_tir_transform_lower_warp_memory.py          | 26 ++++++++++---------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 4e4abd9764c3..7ebcfa6beb6f 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -116,7 +116,7 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
   const char* name = T()(call->dtype, call->name);
   *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
 }
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index bd553772e087..c3cf28982cdc 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -136,30 +136,32 @@ def check_cuda(dtype):
             print("Skip because gpu does not have fp16 support")
             return
 
-        m = 16
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+        n, m = 16, 16
+        A = te.placeholder((n, m,), name='A', dtype=dtype)
+        B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')
 
         cuda_target = tvm.target.create("cuda")
         assert cuda_target.thread_warp_size == 2 * m
         with cuda_target:
             s = te.create_schedule(B.op)
             tx = te.thread_axis("threadIdx.x")
+            ty = te.thread_axis("threadIdx.y")
             bx = te.thread_axis("blockIdx.x")
 
             AA = s.cache_read(A, "warp", [B])
-            xo, xi = s[B].split(B.op.axis[0], nparts=1)
-            s[B].bind(xi, tx)
-            s[B].bind(xo, bx)
-            s[AA].compute_at(s[B], xo)
-            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
-            s[AA].bind(xo, bx)
-            s[AA].bind(xi, tx)
+            y, x = B.op.axis
+            z, y = s[B].split(y, nparts=2)
+            s[B].bind(x, tx)
+            s[B].bind(y, ty)
+            s[B].bind(z, bx)
+            s[AA].compute_at(s[B], y)
+            _, x = AA.op.axis
+            s[AA].bind(x, tx)
 
             ctx = tvm.gpu(0)
             func = tvm.build(s, [A, B], "cuda")
-            A_np = np.array(list(range(m)), dtype=dtype)
-            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+            B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)
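
For reference (not part of the patch): the fourth argument being restored here is the shuffle width. CUDA's __shfl_sync family accepts an optional width parameter, and because the codegen previously dropped call->args[3], every generated shuffle fell back to the default full-warp width. The standalone CUDA sketch below illustrates the sub-warp behavior that the updated half-a-warp test relies on; the file name, kernel name, and host boilerplate are invented for illustration and are not taken from TVM.

// half_warp_shuffle_demo.cu -- illustrative sketch only, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

// Each 16-lane group rotates its values by one position, mirroring the
// B[j, i] = A[j, (i + 1) % m] pattern in the updated test (m = 16, warp size 32).
__global__ void rotate_half_warp(const float* in, float* out) {
  int lane = threadIdx.x;  // 0..31: one 32-lane warp covers two 16-element rows
  float v = in[lane];
  // width = 16 confines the shuffle to the lane's own 16-lane subsection.
  // Omitting the argument defaults to warpSize (32), so in this kernel
  // lanes 16..31 would wrongly read from lanes 0..15, which is the kind of
  // cross-group leakage the codegen fix prevents.
  out[lane] = __shfl_sync(0xffffffffu, v, (lane + 1) % 16, 16);
}

int main() {
  const int n = 32;
  float h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  rotate_half_warp<<<1, n>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  // Expect 1..15,0 for the first half-warp and 17..31,16 for the second.
  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Building with plain nvcc (CUDA 9 or later for __shfl_sync) and running on any recent GPU should print the two halves rotated independently, the same per-row rotation the updated test checks against B_np.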