Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
gather after sorting indices
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed May 27, 2024
1 parent 7352f1c commit 4ec8bba
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 17 deletions.
77 changes: 77 additions & 0 deletions cpp/src/wholememory_ops/functions/gather_func.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down Expand Up @@ -76,6 +84,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
int64_t*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
Expand All @@ -97,6 +107,73 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
embedding_desc,
indices,
indices_desc,
false,
nullptr,
output,
output_desc,
stream,
gather_sms);
} catch (const wholememory::cuda_error& rle) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
return WHOLEMEMORY_LOGIC_ERROR;
}
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
int64_t* raw_indices,
wholememory_array_description_t raw_indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
try {
bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype);
WHOLEMEMORY_CHECK(embedding_is_float ||
wholememory_dtype_is_integer_number(embedding_desc.dtype));
bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype);
WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype));
WHOLEMEMORY_EXPECTS(
embedding_is_float == output_is_float,
"embedding and output should be same number type, e.g. floating number or integer number.");
if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size);
wholememory_error_code_t (*p_gather_func)(wholememory_gref_t,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
int64_t*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
int) = nullptr;
if (embedding_is_float) {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_floating_int32_func;
} else {
p_gather_func = gather_floating_int64_func;
}
} else {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_integer_int32_func;
} else {
p_gather_func = gather_integer_int64_func;
}
}
return p_gather_func(embedding_gref,
embedding_desc,
indices,
indices_desc,
true,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
embedding_gref, embedding_desc, indices, indice_count, gather_with_sorted_ids, raw_indices, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32,
Expand All @@ -45,6 +47,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +67,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
embedding_gref, embedding_desc, indices, indice_count, gather_with_sorted_ids, raw_indices, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64,
Expand All @@ -45,6 +47,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +67,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
embedding_gref, embedding_desc, indices, indice_count, gather_with_sorted_ids, raw_indices, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32,
Expand All @@ -45,6 +47,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +67,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
embedding_gref, embedding_desc, indices, indice_count, gather_with_sorted_ids, raw_indices, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64,
Expand All @@ -45,6 +47,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +67,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
16 changes: 14 additions & 2 deletions cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
const IndexT* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
Expand Down Expand Up @@ -284,7 +286,8 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,

for (int64_t output_idx = warp_id; output_idx < indice_count;
output_idx += gridDim.x * (blockDim.x / 32)) {
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx;
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx;
if (!use_shm) { my_shared = output_ptr; }
int64_t embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
Expand Down Expand Up @@ -323,6 +326,8 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
const IndexT* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
Expand All @@ -345,7 +350,8 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<OutputT, ALIGNMENT> outputs;
for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) {
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx;
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx;
IndexT embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
int64_t embedding_offset =
Expand All @@ -370,6 +376,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -392,6 +400,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t,
const IndexT*,
int64_t,
bool,
int64_t*,
OutputT*,
wholememory_matrix_description_t) = nullptr;

Expand Down Expand Up @@ -495,6 +505,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
embedding_desc,
static_cast<const IndexT*>(indices),
indice_count,
gather_with_sorted_ids,
raw_indices,
static_cast<OutputT*>(output),
output_desc);
WM_CUDA_CHECK(cudaGetLastError());
Expand Down
Loading

0 comments on commit 4ec8bba

Please sign in to comment.