Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize string gather performance for large strings #7980

Merged
79 changes: 46 additions & 33 deletions cpp/include/cudf/strings/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ namespace detail {
// Helper function for loading 16B from a potentially unaligned memory location to registers.
__forceinline__ __device__ uint4 load_uint4(const char* ptr)
{
unsigned int* aligned_ptr = (unsigned int*)((size_t)ptr & ~(3));
uint4 regs = {0, 0, 0, 0};
auto const offset = reinterpret_cast<std::uintptr_t>(ptr) % 4;
auto const* aligned_ptr = reinterpret_cast<unsigned int const*>(ptr - offset);
auto const shift = offset * 8;

regs.x = aligned_ptr[0];
regs.y = aligned_ptr[1];
regs.z = aligned_ptr[2];
regs.w = aligned_ptr[3];
uint tail = aligned_ptr[4];

unsigned int shift = ((size_t)ptr & 3) * 8;
uint4 regs = {aligned_ptr[0], aligned_ptr[1], aligned_ptr[2], aligned_ptr[3]};
uint tail = 0;
if (shift) tail = aligned_ptr[4];

regs.x = __funnelshift_r(regs.x, regs.y, shift);
regs.y = __funnelshift_r(regs.y, regs.z, shift);
Expand All @@ -68,40 +65,56 @@ __global__ void gather_chars_fn_string_parallel(StringIterator strings_begin,
MapIterator string_indices,
size_type total_out_strings)
{
constexpr size_t datatype_size = sizeof(uint4);
constexpr size_t threads_per_warp = 32;
gaohao95 marked this conversation as resolved.
Show resolved Hide resolved

int global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
int global_warp_id = global_thread_id / 32;
int warp_lane = global_thread_id % 32;
int nwarps = gridDim.x * blockDim.x / 32;
int global_warp_id = global_thread_id / threads_per_warp;
int warp_lane = global_thread_id % threads_per_warp;
int nwarps = gridDim.x * blockDim.x / threads_per_warp;

size_t alignment_offset = reinterpret_cast<size_t>(out_chars) & 15;
uint4* out_chars_aligned = reinterpret_cast<uint4*>(out_chars - alignment_offset);
auto const alignment_offset = reinterpret_cast<std::uintptr_t>(out_chars) % datatype_size;
uint4* out_chars_aligned = reinterpret_cast<uint4*>(out_chars - alignment_offset);

for (size_type istring = global_warp_id; istring < total_out_strings; istring += nwarps) {
auto out_start = out_offsets[istring];
auto out_end = out_offsets[istring + 1];
auto const out_start = out_offsets[istring];
auto const out_end = out_offsets[istring + 1];

// This check is necessary because string_indices[istring] may be out of bounds.
if (out_start == out_end) continue;

const char* in_start = strings_begin[string_indices[istring]].data();

int32_t out_start_aligned = (out_start + alignment_offset + 15) / 16 * 16 - alignment_offset;
int32_t out_end_aligned = (out_end + alignment_offset) / 16 * 16 - alignment_offset;

for (size_type ichar = out_start_aligned + warp_lane * 16; ichar < out_end_aligned;
ichar += 32 * 16) {
*(out_chars_aligned + (ichar + alignment_offset) / 16) =
// Both `out_start_aligned` and `out_end_aligned` are indices into `out_chars`.
// `out_start_aligned` is the first 16B aligned memory location at or after `out_start`.
// `out_end_aligned` is the last 16B aligned memory location at or before `out_end`. Characters
// in the range `[out_start_aligned, out_end_aligned)` will be copied using uint4.
int32_t out_start_aligned =
(out_start + alignment_offset + datatype_size - 1) / datatype_size * datatype_size -
alignment_offset;
int32_t out_end_aligned =
(out_end + alignment_offset) / datatype_size * datatype_size - alignment_offset;

for (size_type ichar = out_start_aligned + warp_lane * datatype_size; ichar < out_end_aligned;
ichar += threads_per_warp * datatype_size) {
*(out_chars_aligned + (ichar + alignment_offset) / datatype_size) =
gaohao95 marked this conversation as resolved.
Show resolved Hide resolved
load_uint4(in_start + ichar - out_start);
}

// Tail logic: copy characters of the current string outside `[out_start_aligned,
// out_end_aligned)`.
if (out_end_aligned <= out_start_aligned) {
gaohao95 marked this conversation as resolved.
Show resolved Hide resolved
// In this case, `[out_start_aligned, out_end_aligned)` is an empty set, and we copy the
// entire string. Note that for 16B alignment, the maximum number of characters in this string
// is less than 32, so for each thread in the warp, copying one byte is enough.
int32_t ichar = out_start + warp_lane;
if (ichar < out_end) { out_chars[ichar] = in_start[warp_lane]; }
} else {
// Copy characters in range `[out_start, out_start_aligned)`.
if (out_start + warp_lane < out_start_aligned) {
out_chars[out_start + warp_lane] = in_start[warp_lane];
}

// Copy characters in range `[out_end_aligned, out_end)`.
int32_t ichar = out_end_aligned + warp_lane;
if (ichar < out_end) { out_chars[ichar] = in_start[ichar - out_start]; }
}
Expand All @@ -112,10 +125,10 @@ __global__ void gather_chars_fn_string_parallel(StringIterator strings_begin,
// This strategy assigns characters to threads, and uses binary search for getting the string
// index. To improve the binary search performance, a fixed number of strings per threadblock is
// used. This strategy is best suited for small strings.
constexpr static int strings_per_threadblock = 32;

// Binary search for `value` in `offsets` of length `nelements`. Requires `nelements` to be less than
// or equal to `strings_per_threadblock`, and `strings_per_threadblock` to be a power of 2.
template <int strings_per_threadblock>
__forceinline__ __device__ size_type binary_search(int32_t* offsets,
gaohao95 marked this conversation as resolved.
Show resolved Hide resolved
int32_t value,
size_type nelements)
Expand All @@ -128,7 +141,7 @@ __forceinline__ __device__ size_type binary_search(int32_t* offsets,
return idx;
}

template <typename StringIterator, typename MapIterator>
template <int strings_per_threadblock, typename StringIterator, typename MapIterator>
__global__ void gather_chars_fn_char_parallel(StringIterator strings_begin,
char* out_chars,
cudf::device_span<int32_t const> const out_offsets,
Expand Down Expand Up @@ -156,8 +169,8 @@ __global__ void gather_chars_fn_char_parallel(StringIterator strings_begin,
out_ibyte < out_offsets_threadblock[strings_current_threadblock];
out_ibyte += blockDim.x) {
// binary search for the string index corresponding to out_ibyte
size_type string_idx =
binary_search(out_offsets_threadblock, out_ibyte, strings_current_threadblock);
size_type string_idx = binary_search<strings_per_threadblock>(
out_offsets_threadblock, out_ibyte, strings_current_threadblock);

// calculate which character to load within the string
int32_t icharacter = out_ibyte - out_offsets_threadblock[string_idx];
Expand Down Expand Up @@ -210,12 +223,12 @@ std::unique_ptr<cudf::column> gather_chars(StringIterator strings_begin,
stream.value()>>>(
strings_begin, d_chars, offsets, map_begin, output_count);
} else {
gather_chars_fn_char_parallel<<<(output_count + strings_per_threadblock - 1) /
strings_per_threadblock,
128,
0,
stream.value()>>>(
strings_begin, d_chars, offsets, map_begin, output_count);
constexpr int strings_per_threadblock = 32;
gather_chars_fn_char_parallel<strings_per_threadblock>
<<<(output_count + strings_per_threadblock - 1) / strings_per_threadblock,
128,
0,
stream.value()>>>(strings_begin, d_chars, offsets, map_begin, output_count);
}

return chars_column;
Expand Down