[TOPI][CUDA] Improve the performance of scatter_nd by:
1. Split into 2 kernels: one does the "Init" and the other does the "Update".
   This lets each kernel use its own Grid/Block configuration to better
   utilize the SMs.
2. Use atomic_add instead of direct assignment, which avoids the race
   condition when multiple indices point to the same location in the output
   tensor. With this modification it is now safe to use more CUDA threads
   to gain more parallelism (see the sketch below).
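Below is a minimal NumPy sketch (an editor's illustration, not code from this commit) of the race behind point 2: with duplicate indices, a plain fancy-index assignment keeps only one of the colliding contributions, whereas an accumulating update, which is what atomic_add provides on the GPU, keeps them all. The shapes and values are made up.

import numpy as np

out = np.zeros(4, dtype=np.float32)
indices = np.array([1, 1, 3])            # index 1 appears twice
updates = np.array([10.0, 20.0, 5.0], dtype=np.float32)

racy = out.copy()
racy[indices] += updates                 # buffered fancy indexing: one of the duplicate updates is lost

safe = out.copy()
np.add.at(safe, indices, updates)        # accumulates every update, analogous to atomic_add per element

print(racy)  # [ 0. 20.  0.  5.]
print(safe)  # [ 0. 30.  0.  5.]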
wenxizhu committed Jul 15, 2021
1 parent 1a9bcc5 commit 930043e
Showing 1 changed file with 28 additions and 28 deletions.
python/tvm/topi/cuda/scatter.py: 28 additions, 28 deletions (56 changes)
@@ -787,42 +787,42 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
-        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+        # Init output tensor.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = 1
+            for i in data_ptr.shape[:-1]:
+                gridDim *= i
+            blockDim = data_ptr.shape[-1]
+
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+            index = bidx * blockDim + tidx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
+        # Update output tensor by given values.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = fused_indices_dimension  # 32 * 600 = 19200
+            blockDim = fused_updates_dimension
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+
+            j = tidx
             with ib.if_scope(j < fused_updates_dimension):
                 offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                 # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
+                findex = j
+                for l in reversed(range(indices_ptr.shape[0].value)):  # 2, 1, 0
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
+                    findex += offset * indices[bidx + l * fused_indices_dimension]
                     offset *= data_ptr.shape[l]
                 if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
+                    out[findex] = updates[bidx * fused_updates_dimension + tidx]
                 elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
+                    out[findex] = atomic_add(
+                        tvm.tir.call_intrin("handle", "tir.address_of", out[findex]),
+                        updates[bidx * fused_updates_dimension + j],
+                    )
                 else:
                     raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

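For reference, a small self-contained Python sketch (an editor's illustration with made-up shapes, not code from the commit) of the flattened-index arithmetic used in the update kernel above: the leading output dimensions are taken from indices, and the trailing dimensions are covered by the thread index j.

# Hypothetical output shape (5, 7, 3); `indices` covers the first M = 2 dimensions,
# so the fused update dimension is 3 (the product of the trailing dimensions).
data_shape = (5, 7, 3)
M = 2
fused_updates_dimension = 3

def flat_index(idx, j):
    # idx = (y_0, ..., y_{M-1}) read from `indices`; j = flattened offset into the trailing dims.
    offset = fused_updates_dimension
    findex = j
    for l in reversed(range(M)):
        findex += offset * idx[l]
        offset *= data_shape[l]
    return findex

# Row-major check: element (2, 4, 1) of a (5, 7, 3) tensor sits at (2 * 7 + 4) * 3 + 1 = 55.
assert flat_index((2, 4), 1) == 55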
