[TOPI][CUDA] Improve the performance of scatter_nd #8479

Merged
Changes from 2 commits
59 changes: 31 additions & 28 deletions python/tvm/topi/cuda/scatter.py
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
 for i in data_ptr.shape:
     fused_shape *= i

-# For now we avoid parallelizing over dimensions indexed by `indices`, as
-# there may be repeated indices and handling parallel accumulation can
-# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-# work well when these dimensions are large enough to saturate memory
-# bandwidth, but performance will be bad when these dimensions are
-# small.
Comment on lines -790 to -795

Contributor: Can you add a comment about how we are doing parallelism? (We are thread-parallel over the whole update dimension, and each block handles one set of indices?)

@CaptainDuke (Contributor), Jul 20, 2021: We follow the original parallelism scheme, but replace ib.for_range() with blockIdx.y. atomic_add guarantees correctness when mode == "add".

Contributor: Can you update the comment in the code to reflect this?

Contributor: Added.
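
To make the new scheme concrete, here is a minimal sketch of the transformation described in this thread (illustrative only, not lines from the PR; body() is a hypothetical placeholder for the per-index-set scatter logic):

    # Before: a serial ib.for_range() walks every set of indices in one kernel.
    with ib.for_range(0, fused_indices_dimension) as i:
        body(i)  # hypothetical placeholder for the scatter body

    # After: the loop variable becomes a block index, so index sets run in
    # parallel across blocks. This commit binds blockIdx.x; the reply above
    # mentions blockIdx.y, which would play the same role.
    bidx = te.thread_axis("blockIdx.x")
    ib.scope_attr(bidx, "thread_extent", fused_indices_dimension)
    body(bidx)  # same placeholder, now driven by the block index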

-bx = te.thread_axis("blockIdx.x")
-tx = te.thread_axis("threadIdx.x")
-max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-tdim = min(max_threads, fused_updates_dimension)
-ib.scope_attr(tx, "thread_extent", tdim)
-bdim = ceil_div(fused_updates_dimension, tdim)
-ib.scope_attr(bx, "thread_extent", bdim)
-
-# Copy data into the output. This loop writes to the same portions of
-# memory as the following loop, so we do not need a memory sync.
-with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-    index = i * fused_updates_dimension + bx * tdim + tx
-    with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+# Init output tensor.
+with ib.new_scope():
+    bidx = te.thread_axis("blockIdx.x")
+    tidx = te.thread_axis("threadIdx.x")
+    gridDim = 1
+    for i in data_ptr.shape[:-1]:
+        gridDim *= i
+    blockDim = data_ptr.shape[-1]
+
+    ib.scope_attr(bidx, "thread_extent", gridDim)
+    ib.scope_attr(tidx, "thread_extent", blockDim)
Contributor: In some cases this dimension will be very small. Can you instead split the full shape by max_num_threads?
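
A sketch of what the reviewer's suggestion could look like, reusing ceil_div, bidx, and tidx from the surrounding code (an illustration of the suggestion, not code from the PR):

    # Cover the whole flattened output with blocks of up to max_num_threads
    # threads, instead of tying blockDim to a possibly tiny innermost dimension.
    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    tdim = min(max_threads, fused_shape)  # threads per block
    bdim = ceil_div(fused_shape, tdim)    # blocks needed to cover fused_shape
    ib.scope_attr(bidx, "thread_extent", bdim)
    ib.scope_attr(tidx, "thread_extent", tdim)
    index = bidx * tdim + tidx
    with ib.if_scope(index < fused_shape):
        out[index] = data[index]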

+    index = bidx * blockDim + tidx
+    with ib.if_scope(index < fused_shape):
         out[index] = data[index]

-with ib.for_range(0, fused_indices_dimension) as i:
-    j = bx * tdim + tx
+# Update output tensor by given values.
+with ib.new_scope():
+    bidx = te.thread_axis("blockIdx.x")
+    tidx = te.thread_axis("threadIdx.x")
+    gridDim = fused_indices_dimension  # 32 * 600 = 19200
Contributor: Remove this comment.

Contributor (Author): Fixed.

+    blockDim = fused_updates_dimension
+    ib.scope_attr(bidx, "thread_extent", gridDim)
+    ib.scope_attr(tidx, "thread_extent", blockDim)
+
+    j = tidx
     with ib.if_scope(j < fused_updates_dimension):
         offset = fused_updates_dimension
-        index = j  # This is the x_M, .. x_{N-1} part of the index into out.
-        # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-        # of the index into out.
Comment on lines -815 to -817

Contributor: Can you keep this comment? I believe it still holds.

Contributor: Added.

-        for l in reversed(range(indices_ptr.shape[0].value)):
+        findex = j
Contributor: You've set j = tidx and then only use it in one spot. Why not just use tidx everywhere?

Contributor: Fixed.

+        for l in reversed(range(indices_ptr.shape[0].value)):  # 2, 1, 0
-            # indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-            index += offset * indices[i + l * fused_indices_dimension]
+            findex += offset * indices[bidx + l * fused_indices_dimension]
             offset *= data_ptr.shape[l]
         if mode == "update":
-            out[index] = updates[i * fused_updates_dimension + j]
+            out[findex] = updates[bidx * fused_updates_dimension + tidx]
         elif mode == "add":
-            out[index] += updates[i * fused_updates_dimension + j]
+            out[findex] = atomic_add(
+                tvm.tir.call_intrin("handle", "tir.address_of", out[findex]),
+                updates[bidx * fused_updates_dimension + j],
+            )
         else:
             raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
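
A note on the mode == "add" branch above: with each set of indices now handled by its own block, duplicates in indices mean different blocks can update the same output element, so the old out[index] += ... read-modify-write would race. The new code instead takes the element's address with tir.address_of and passes it to this file's atomic_add helper, which lowers to CUDA's atomicAdd and serializes concurrent updates to one location. Side by side (illustrative, names as in the diff):

    # Racy under block-parallel index sets: += is a non-atomic read-modify-write.
    out[findex] += updates[bidx * fused_updates_dimension + tidx]

    # Safe: the increment is applied atomically at the address of out[findex].
    out[findex] = atomic_add(
        tvm.tir.call_intrin("handle", "tir.address_of", out[findex]),
        updates[bidx * fused_updates_dimension + tidx],
    )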
