[TOPI][CUDA] Improve the performance of scatter_nd #8479
Changes from 2 commits
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
     for i in data_ptr.shape:
         fused_shape *= i
 
-    # For now we avoid parallizing over dimensions indexed by `indices` as
-    # there may be repeated indices and hadling parallel accumulation can
-    # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-    # work well when these dimensions are large enough to saturate memory
-    # bandwidth, but performance will be bad when these dimensions are
-    # small.
-    bx = te.thread_axis("blockIdx.x")
-    tx = te.thread_axis("threadIdx.x")
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    tdim = min(max_threads, fused_updates_dimension)
-    ib.scope_attr(tx, "thread_extent", tdim)
-    bdim = ceil_div(fused_updates_dimension, tdim)
-    ib.scope_attr(bx, "thread_extent", bdim)
-
-    # Copy data into the output. This loop writes to the same portions of
-    # memory as the following loop, so we do not need a memory sync.
-    with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-        index = i * fused_updates_dimension + bx * tdim + tx
-        with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+    # Init output tensor.
+    with ib.new_scope():
+        bidx = te.thread_axis("blockIdx.x")
+        tidx = te.thread_axis("threadIdx.x")
+        gridDim = 1
+        for i in data_ptr.shape[:-1]:
+            gridDim *= i
+        blockDim = data_ptr.shape[-1]
+
+        ib.scope_attr(bidx, "thread_extent", gridDim)
+        ib.scope_attr(tidx, "thread_extent", blockDim)
Review comment: In some cases this dimension will be very small. Can you instead split the full shape by max_num_threads?
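For context, a rough sketch of what the reviewer is suggesting, assuming the surrounding gen_ir variables (ib, data, out, fused_shape) and the module's ceil_div helper are in scope; the structure is illustrative, not the code that was eventually merged:

```python
# Illustrative only: cover the whole flattened output with blocks of
# max_num_threads threads, so the init kernel stays busy even when the
# last dimension of data_ptr.shape is small.
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
    bidx = te.thread_axis("blockIdx.x")
    tidx = te.thread_axis("threadIdx.x")
    ib.scope_attr(bidx, "thread_extent", ceil_div(fused_shape, max_threads))
    ib.scope_attr(tidx, "thread_extent", max_threads)
    index = bidx * max_threads + tidx
    # Guard the tail block, since fused_shape need not be a multiple of max_threads.
    with ib.if_scope(index < fused_shape):
        out[index] = data[index]
```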
+        index = bidx * blockDim + tidx
+        with ib.if_scope(index < fused_shape):
             out[index] = data[index]
 
-    with ib.for_range(0, fused_indices_dimension) as i:
-        j = bx * tdim + tx
+    # Update output tensor by given values.
+    with ib.new_scope():
+        bidx = te.thread_axis("blockIdx.x")
+        tidx = te.thread_axis("threadIdx.x")
+        gridDim = fused_indices_dimension  # 32 * 600 = 19200
Review comment: Remove this comment.
Reply: Fixed.
+        blockDim = fused_updates_dimension
+        ib.scope_attr(bidx, "thread_extent", gridDim)
+        ib.scope_attr(tidx, "thread_extent", blockDim)
+
+        j = tidx
         with ib.if_scope(j < fused_updates_dimension):
             offset = fused_updates_dimension
-            index = j  # This is x_M, .. x_{N-1} part of the index into out.
-            # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-            # of the index into out.
Review comment on original lines 815 to 817: Can you keep this comment? I believe it still holds.
Reply: Added.
-            for l in reversed(range(indices_ptr.shape[0].value)):
+            findex = j
Review comment: You've set …
Reply: Fixed.
+            for l in reversed(range(indices_ptr.shape[0].value)):  # 2, 1, 0
-                # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                index += offset * indices[i + l * fused_indices_dimension]
+                findex += offset * indices[bidx + l * fused_indices_dimension]
                 offset *= data_ptr.shape[l]
             if mode == "update":
-                out[index] = updates[i * fused_updates_dimension + j]
+                out[findex] = updates[bidx * fused_updates_dimension + tidx]
             elif mode == "add":
-                out[index] += updates[i * fused_updates_dimension + j]
+                out[findex] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", out[findex]),
+                    updates[bidx * fused_updates_dimension + j],
+                )
             else:
                 raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
Review comment: Can you add a comment about how we are doing parallelism (we are thread-parallel over all of the update dimensions, and each block handles one set of indices)?
Reply: We follow the original parallelism scheme, but replace ib.for_range() with blockIdx.y. atomic_add guarantees correctness when mode == "add".
Review comment: Can you update the comment in the code to reflect this?
Reply: Added.
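A sketch of the parallelization scheme described in this thread, with illustrative names, assuming max_threads, fused_updates_dimension, fused_indices_dimension, ceil_div, and the atomic_add helper are in scope as elsewhere in gen_ir; it is a reading of the discussion, not necessarily the exact code from the later commits:

```python
with ib.new_scope():
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")   # one block row per set of scattered indices
    tx = te.thread_axis("threadIdx.x")  # threads are parallel over the fused update dimension
    tdim = min(max_threads, fused_updates_dimension)
    ib.scope_attr(tx, "thread_extent", tdim)
    ib.scope_attr(bx, "thread_extent", ceil_div(fused_updates_dimension, tdim))
    ib.scope_attr(by, "thread_extent", fused_indices_dimension)
    j = bx * tdim + tx
    with ib.if_scope(j < fused_updates_dimension):
        offset = fused_updates_dimension
        index = j  # x_M .. x_{N-1} part of the index into out
        for l in reversed(range(indices_ptr.shape[0].value)):
            index += offset * indices[by + l * fused_indices_dimension]
            offset *= data_ptr.shape[l]
        if mode == "update":
            out[index] = updates[by * fused_updates_dimension + j]
        elif mode == "add":
            # atomic_add makes concurrent accumulation into the same out[index] safe
            # when different index sets point at the same output location.
            out[index] = atomic_add(
                tvm.tir.call_intrin("handle", "tir.address_of", out[index]),
                updates[by * fused_updates_dimension + j],
            )
```

Splitting fused_updates_dimension across blockIdx.x with ceil_div keeps the kernel correct even when the update dimension exceeds max_num_threads.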