Skip to content

Commit

Permalink
[TOPI] GPU scatter 1D via sorting based approach (#7056)
Browse files Browse the repository at this point in the history
* add thrust stable sort

* rename

* scatter via sort working

* correctly handles negative indices

* clean up, add some comments

* add doc string

* remove scatter benchmark stuff

* add more doc

* fix typo

* lint fix

* silence lint

* fix py format

* check for thrust availablity before test

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Dec 9, 2020
1 parent 0095b21 commit 465cd14
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 2 deletions.
1 change: 1 addition & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ if(USE_CUDA)
message(STATUS "Build with Thrust support")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
enable_language(CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")
file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)
Expand Down
106 changes: 105 additions & 1 deletion python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm import te
from ..scatter import _verify_scatter_nd_inputs
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available


def ceil_div(a, b):
Expand Down Expand Up @@ -416,6 +417,97 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
return ib.get()


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
"""Generate scatter ir for 1d inputs, using a sorting based approach.
By sorting indices and comparing neighboring two indices, we can tell which
of elements in the indices tensor can scatter its update value into the output.
Sorting of indices, and sorting of updates with respect to indices, can be done
at the same time by thrust's sort_by_key function. It is important that sorting
be done in a "stable" way via stable_sort, to guarantee deterministic output.
Parameters
----------
data : tir.Tensor
The input data to the operator.
indices_sorted : tir.Tensor
The sorted index locations to update.
updates : tir.Tensor
The values to update, sorted by indices.
axis : int
The axis to scatter on. It must be 0 for this function.
out : tir.Tensor
The output tensor.
Returns
-------
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

with ib.new_scope():
nthread_bx = ceil_div(n, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < n):
out_ptr[tid] = data_ptr[tid]

indices_ptr = ib.buffer_ptr(indices_sorted)
updates_ptr = ib.buffer_ptr(updates_sorted)

ni = indices_sorted.shape[0]

def do_update(ib, index, update):
with ib.if_scope(index < 0):
out_ptr[index + n] = update
with ib.else_scope():
out_ptr[index] = update

with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.if_scope(tid == ni - 1):
# The last element can always update.
index = indices_ptr[tid]
update = updates_ptr[tid]
do_update(ib, index, update)

with ib.else_scope():
with ib.if_scope(tid < ni - 1):
index = indices_ptr[tid]
index_next = indices_ptr[tid + 1]

# If the next neighbor in the sorted list of indices has a different index,
# that means thread tid is the last one to have this index.
# This thread can update the output.
with ib.if_scope(index != index_next):
update = updates_ptr[tid]
do_update(ib, index, update)

return ib.get()


def scatter(data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates
Expand Down Expand Up @@ -458,9 +550,21 @@ def update_func(dst_ptr, dst_index, update):

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

in_bufs = [data]

if rank == 1 and is_thrust_available():
ir_funcs[1] = gen_scatter_1d_thrust
indices_sorted, updates_sorted = stable_sort_by_key_thrust(
indices, updates, for_scatter=True
)
in_bufs += [indices_sorted, updates_sorted]
else:
in_bufs += [indices, updates]

out = te.extern(
[out_shape],
[data, indices, updates],
in_bufs,
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
Expand Down
59 changes: 58 additions & 1 deletion python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Argsort operator """
"""Sort related operators """
import tvm
from tvm import te
from tvm._ffi import get_global_func

from .injective import schedule_injective_from_existing
from ..math import identity
Expand Down Expand Up @@ -597,3 +598,59 @@ def schedule_topk(outs):
The computation schedule for the op.
"""
return _schedule_sort(outs)


def stable_sort_by_key_thrust(keys, values, for_scatter=False):
"""Sort values with respect to keys using thrust.
Both keys and values will be sorted and returned.
Sorting is done via stable sort, so relative ordering among
ties are preserved.
Parameters
----------
keys: tvm.te.Tensor
The 1D input keys.
values : tvm.te.Tensor,
The 1D input values.
for_scatter: bool, optional
If True, negative keys are interpreted as negative indices.
Before sorting, negative indices are converted to corresponding positive indices.
The output keys (indices) are all positive.
This option is introduced to optimize the scatter implementation.
Returns
-------
keys_sorted : tvm.te.Tensor
The sorted keys
values_sorted : tvm.te.Tensor
The values sorted with respect to the keys
"""
keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8)
values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8),
tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8),
]
out = te.extern(
[keys.shape, values.shape],
[keys, values],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], outs[1], for_scatter
),
in_buffers=[keys_buf, values_buf],
out_buffers=out_bufs,
dtype=[keys.dtype, values.dtype],
name="stable_sort_by_key",
tag="stable_sort_by_key",
)
return out[0], out[1]


def is_thrust_available():
"""
Test if thrust based sorting ops are available.
"""
return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
73 changes: 73 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,78 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});

template<typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in,
DLTensor* values_in,
DLTensor* keys_out,
DLTensor* values_out,
bool for_scatter) {
const auto size = keys_in->shape[0];
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));

if (for_scatter) {
thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
} else {
thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
}
thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);

thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});

} // namespace contrib
} // namespace tvm
34 changes: 34 additions & 0 deletions tests/python/contrib/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import tvm.testing
from tvm import te
from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
import numpy as np


Expand Down Expand Up @@ -90,6 +91,39 @@ def test_sort_np():
tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5)


def test_thrust_stable_sort_by_key():
if not is_thrust_available():
print("skip because thrust is not enabled...")
return

size = 6
keys = te.placeholder((size,), name="keys", dtype="int32")
values = te.placeholder((size,), name="values", dtype="int32")

keys_out, values_out = stable_sort_by_key_thrust(keys, values)

ctx = tvm.gpu(0)
target = "cuda"
s = te.create_schedule([keys_out.op, values_out.op])
f = tvm.build(s, [keys, values, keys_out, values_out], target)

keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
keys_np_out = np.zeros(keys_np.shape, np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)
keys_in = tvm.nd.array(keys_np, ctx)
values_in = tvm.nd.array(values_np, ctx)
keys_out = tvm.nd.array(keys_np_out, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(keys_in, values_in, keys_out, values_out)

ref_keys_out = np.sort(keys_np)
ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


if __name__ == "__main__":
test_sort()
test_sort_np()
test_thrust_stable_sort_by_key()

0 comments on commit 465cd14

Please sign in to comment.