Skip to content

Commit

Permalink
[TOPI][CUDA] Improve the performance of scatter_nd (apache#8479)
Browse files Browse the repository at this point in the history
* [TOPI][CUDA] Improve the performance of scatter_nd by:

1. Split into 2 kernels, one does the "Init" and another does the "Update".
   Thus they can have different Grid/Block configurations to better utilize
   SMs.
2. Use atomic_add instead of direct assignment, which could avoid the race
   condtion when multiple indices point to the same location of the output
   tensor. With this moidification, it's safe now to use more CUDA threads
   to gain more parallelism.

* Fix python code format.

* FIX: [TOPI][CUDA] Improve the performance of scatter_nd apache#8479

- Split ScatterND kernel into 2 sub-kernels using ib.new_scope()

- Replace ib.for_range() with blockIdx.y

- Using atomic_add when mode == "add"

- Keep threadIdx.x less than max_threads of GPU

* Comment added

* Add fallback implementation when "mode=add" meets int64

- Atomic_add from CUDA doesn't support int64 data type
- Change "ind{i}" to "ind%d"%i, where names of relay.var could correctly display

* Python format

* Fix line too long

* CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Exchange blockIdx.x and blockIdx.y

* check for Vulkan or metal

* Fallback to previous algorithm when mode==update

* Update python/tvm/topi/cuda/scatter.py

Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>

* Assign TODO

* Swapping then and else block

Co-authored-by: wenxizhu <wenxizhu@tencent.com>
Co-authored-by: CaptainDuke <captainduke328@gmail.com>
Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>
  • Loading branch information
4 people authored and ylc committed Jan 13, 2022
1 parent 54a5539 commit 32df591
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 38 deletions.
122 changes: 85 additions & 37 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,9 +772,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
# number of loops. We do the same thing for all the update dimensions.
atomic_add_return = ib.allocate(
updates.dtype, (1,), name="atomic_add_return", scope="local"
)

fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i
Expand All @@ -787,44 +788,91 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
for i in data_ptr.shape:
fused_shape *= i

# For now we avoid parallizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# Copy data into the output. This loop writes to the same portions of
# memory as the following loop, so we do not need a memory sync.
with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(bx * tdim + tx < fused_updates_dimension):

with ib.new_scope():
bdim = ceil_div(fused_shape, tdim)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim)
ib.scope_attr(tx, "thread_extent", tdim)

index = bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = data[index]

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
# For better performance, we introduce blockIdx.y to implement for-loops
# within one thread.
# The code is parallel over the scattered indices, so we use atomic_add
# to guarantee correctness when mode=="add"

# For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64"
# So we fallback to normal algorithm, using "+=" rather than atomic_add

# TODO (CaptainDuke):
# Since multiple threads compete for the same write index, which leads to
# non-determinstic output for update mode. We could add a new attribute,
# "allow_non_deterministic", which can be conditionally set to True by
# each frontend when non-determinsm is allowed.
cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind)
with ib.new_scope():
if (
mode == "add"
and cur_target_kind not in ["vulkan", "metal"]
and updates.dtype in ["int32", "float32"]
):
bdim_x = fused_indices_dimension
bdim_y = ceil_div(fused_updates_dimension, tdim)
# In case of large input sizes, fused_indices_dimension might be too large.
# So we use blockIdx.x because holds larger scales.
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim_x)
ib.scope_attr(by, "thread_extent", bdim_y)
ib.scope_attr(tx, "thread_extent", tdim)

j = by * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
# part of the index into out.
up_index = bx * fused_updates_dimension + j
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[bx * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[bx + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out[index]),
updates[up_index],
)
else:
bdim_x = ceil_div(fused_updates_dimension, tdim)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", bdim_x)
ib.scope_attr(tx, "thread_extent", tdim)
with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the
# indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}]
# part of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l, y_0,
# ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

return ib.get()

Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,8 @@ def verify_scatter_nd_with_stack(
):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices_vars = [
relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype))
for i, v in enumerate(indices_np)
]
updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))

Expand Down

0 comments on commit 32df591

Please sign in to comment.