From 930043e8b7a7aa4edd4afa35886f6d653272e88b Mon Sep 17 00:00:00 2001
From: wenxizhu
Date: Thu, 15 Jul 2021 17:30:44 +0800
Subject: [PATCH] [TOPI][CUDA] Improve the performance of scatter_nd by:

1. Splitting the IR into two kernels, one doing the "Init" and the other
   doing the "Update", so that each gets its own grid/block configuration
   and can better utilize the SMs.
2. Using atomic_add instead of direct assignment, which avoids the race
   condition when multiple indices point to the same location of the
   output tensor. With this modification, it is now safe to use more CUDA
   threads to gain more parallelism.
---
 python/tvm/topi/cuda/scatter.py | 62 ++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 28 deletions(-)

diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index c697b648786e..4808a3293523 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -787,42 +787,48 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
-        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+        # "Init" kernel: copy the input data into the output tensor.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = 1
+            for i in data_ptr.shape[:-1]:
+                gridDim *= i
+            blockDim = data_ptr.shape[-1]
+
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+            index = bidx * blockDim + tidx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
+        # "Update" kernel: one block per index tuple, one thread per element
+        # of the update slice.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = fused_indices_dimension
+            blockDim = fused_updates_dimension
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+
+            j = tidx
             with ib.if_scope(j < fused_updates_dimension):
                 offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-                # of the index into out.
+                findex = j  # This is the x_M, .., x_{N-1} part of the index into out.
+                # Build up the indices[0, y_0, .., y_{K-1}], .., indices[M-1, y_0, .., y_{K-1}]
+                # part of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
+                    # indices[bidx + l * fused_indices_dimension] = indices[l, y_0, .., y_{K-1}]
+                    findex += offset * indices[bidx + l * fused_indices_dimension]
                     offset *= data_ptr.shape[l]
                 if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
+                    out[findex] = updates[bidx * fused_updates_dimension + j]
                 elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
+                    out[findex] = atomic_add(
+                        tvm.tir.call_intrin("handle", "tir.address_of", out[findex]),
+                        updates[bidx * fused_updates_dimension + j],
+                    )
                 else:
                     raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
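
Note: the hunk above calls atomic_add() without importing or defining it. The patch presumably relies on the wrapper TVM already ships in python/tvm/topi/cuda/nms.py being in scope in scatter.py; a minimal sketch of that wrapper, shown for reference only and not as part of the patch:

    import tvm


    def atomic_add(x, y):
        # Wrap the tir.atomic_add intrinsic: atomically add y to the value
        # stored at address x and return the result. x must be a "handle"
        # (a pointer), which is why the "add" branch above first wraps the
        # output element in tir.address_of.
        return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)

On CUDA targets this intrinsic lowers to atomicAdd, which is what makes it safe for multiple blocks to accumulate into the same output element concurrently.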
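
To sanity-check the index arithmetic, here is a serial NumPy mirror of the "Update" kernel (the names bidx, j, findex and offset match the IR above; the shapes and values are made-up illustrations, not taken from the patch). The repeated index in the example is exactly the case that raced in the old single-kernel code:

    import numpy as np


    def scatter_nd_add_ref(data, indices, updates):
        # Serial reference for mode == "add". indices has shape
        # (M, Y_0, .., Y_{K-1}): each of the prod(Y_*) columns selects one
        # slice of out (of size prod(data.shape[M:])) and accumulates one
        # update slice into it.
        m = indices.shape[0]
        fused_indices_dimension = int(np.prod(indices.shape[1:]))
        fused_updates_dimension = int(np.prod(data.shape[m:]))
        out = data.copy().reshape(-1)
        idx = indices.reshape(m, fused_indices_dimension)
        upd = updates.reshape(fused_indices_dimension, fused_updates_dimension)
        for bidx in range(fused_indices_dimension):  # one CUDA block each
            for j in range(fused_updates_dimension):  # one CUDA thread each
                offset, findex = fused_updates_dimension, j
                for l in reversed(range(m)):
                    findex += offset * int(idx[l, bidx])
                    offset *= data.shape[l]
                # Distinct bidx values may yield the same findex (repeated
                # indices). Serially "+=" is fine; on the GPU this is the
                # accumulation that atomic_add protects.
                out[findex] += upd[bidx, j]
        return out.reshape(data.shape)


    data = np.zeros((4, 2), dtype="float32")
    indices = np.array([[0, 0, 1]], dtype="int64")  # index 0 appears twice
    updates = np.ones((3, 2), dtype="float32")
    print(scatter_nd_add_ref(data, indices, updates))

Running it prints [[2. 2.], [1. 1.], [0. 0.], [0. 0.]]: the two update slices aimed at row 0 accumulate instead of clobbering each other.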