Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed May 28, 2024
1 parent 4ec8bba commit 36a7456
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 39 deletions.
15 changes: 8 additions & 7 deletions cpp/src/wholememory_ops/functions/gather_func.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -35,7 +35,7 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -45,7 +45,7 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -55,7 +55,7 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down Expand Up @@ -85,7 +85,7 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
void* indices,
wholememory_array_description_t,
bool,
int64_t*,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
Expand Down Expand Up @@ -127,7 +127,7 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
int64_t* raw_indices,
void* raw_indices,
wholememory_array_description_t raw_indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
Expand All @@ -145,12 +145,13 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin
"embedding and output should be same number type, e.g. floating number or integer number.");
if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size);
WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype);
wholememory_error_code_t (*p_gather_func)(wholememory_gref_t,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
int64_t*,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,7 +48,7 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,7 +48,7 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,7 +48,7 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,7 +48,7 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
const IndexT* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
const IndexT* raw_indices,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
Expand Down Expand Up @@ -286,7 +286,7 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,

for (int64_t output_idx = warp_id; output_idx < indice_count;
output_idx += gridDim.x * (blockDim.x / 32)) {
int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx;
int64_t raw_output_idx = gather_with_sorted_ids ? (int64_t) (raw_indices[output_idx]) : output_idx;
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx;
if (!use_shm) { my_shared = output_ptr; }
int64_t embedding_table_idx = indices[output_idx];
Expand Down Expand Up @@ -327,7 +327,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
const IndexT* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
const IndexT* raw_indices,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
Expand All @@ -350,7 +350,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<OutputT, ALIGNMENT> outputs;
for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) {
int64_t raw_output_idx = gather_with_sorted_ids ? raw_indices[output_idx] : output_idx;
int64_t raw_output_idx = gather_with_sorted_ids ? (int64_t) (raw_indices[output_idx]) : output_idx;
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx;
IndexT embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
Expand All @@ -377,7 +377,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
int64_t* raw_indices,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -401,7 +401,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
const IndexT*,
int64_t,
bool,
int64_t*,
const IndexT*,
OutputT*,
wholememory_matrix_description_t) = nullptr;

Expand Down Expand Up @@ -506,7 +506,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
static_cast<const IndexT*>(indices),
indice_count,
gather_with_sorted_ids,
raw_indices,
static_cast<const IndexT*>(raw_indices),
static_cast<OutputT*>(output),
output_desc);
WM_CUDA_CHECK(cudaGetLastError());
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/wholememory_ops/functions/gather_scatter_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ wholememory_error_code_t gather_with_sorted_ids_func(wholememory_gref_t embeddin
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
int64_t* raw_indices,
void* raw_indices,
wholememory_array_description_t raw_indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/wholememory_ops/functions/sort_indices_func.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ template <typename IndexT>
void sort_indices_temp_func(const void* indices_before_sort,
wholememory_array_description_t indices_desc,
void* indices_after_sort,
int64_t* raw_indices,
void* raw_indices,
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream)
Expand All @@ -52,8 +52,8 @@ void sort_indices_temp_func(const void* indices_before_sort,
WHOLEMEMORY_CHECK(index_type == WHOLEMEMORY_DT_INT || index_type == WHOLEMEMORY_DT_INT64);
wm_thrust_allocator& allocator = *p_thrust_allocator;

int64_t* seq_indices = reinterpret_cast<int64_t*>(allocator.allocate(
wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(int64_t)));
IndexT* seq_indices = reinterpret_cast<IndexT*>(allocator.allocate(
wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(IndexT)));
thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream),
seq_indices,
seq_indices + indices_desc.size,
Expand All @@ -69,7 +69,7 @@ void sort_indices_temp_func(const void* indices_before_sort,
indices_to_sort,
sorted_indice,
seq_indices,
raw_indices,
static_cast<IndexT*>(raw_indices),
indices_desc.size,
0,
sizeof(UTypeT) * 8,
Expand All @@ -80,7 +80,7 @@ void sort_indices_temp_func(const void* indices_before_sort,
indices_to_sort,
sorted_indice,
seq_indices,
raw_indices,
static_cast<IndexT*>(raw_indices),
indices_desc.size,
0,
sizeof(UTypeT) * 8,
Expand All @@ -90,19 +90,19 @@ void sort_indices_temp_func(const void* indices_before_sort,
allocator.deallocate(static_cast<char*>(cub_temp_storage), temp_storage_bytes);
}

REGISTER_DISPATCH_ONE_TYPE(SortIDs, sort_indices_temp_func, SINT3264)
REGISTER_DISPATCH_ONE_TYPE(SortIndices, sort_indices_temp_func, SINT3264)

wholememory_error_code_t sort_indices_func(const void* indices_before_sort,
wholememory_array_description_t indice_desc,
void* indices_after_sort,
int64_t* raw_indices,
void* raw_indices,
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream)
{
try {
DISPATCH_ONE_TYPE(indice_desc.dtype,
SortIDs,
SortIndices,
indices_before_sort,
indice_desc,
indices_after_sort,
Expand All @@ -111,10 +111,10 @@ wholememory_error_code_t sort_indices_func(const void* indices_before_sort,
p_env_fns,
stream);
} catch (wholememory::cuda_error& wce) {
WHOLEMEMORY_ERROR("exchange_ids_func CUDA LOGIC Error %s\n", wce.what());
WHOLEMEMORY_ERROR("sort_indices_func CUDA LOGIC Error %s\n", wce.what());
return WHOLEMEMORY_CUDA_ERROR;
} catch (wholememory::logic_error& wle) {
WHOLEMEMORY_ERROR("exchange_ids_func LOGIC Error %s\n", wle.what());
WHOLEMEMORY_ERROR("sort_indices_func LOGIC Error %s\n", wle.what());
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
return WHOLEMEMORY_UNKNOW_ERROR;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/wholememory_ops/functions/sort_indices_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace wholememory_ops {
wholememory_error_code_t sort_indices_func(const void* indices_before_sort,
wholememory_array_description_t indice_desc,
void* indices_after_sort,
int64_t* raw_indices,
void* raw_indices,
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);
Expand Down
Loading

0 comments on commit 36a7456

Please sign in to comment.