Skip to content

Commit

Permalink
fix block size conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
FSSRepo committed Mar 24, 2024
1 parent 0a13dfe commit 19775b0
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ggml-cuda.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -625,6 +625,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256
// Threads per block for the flash_ext_f32_f16 / flash_ext_f16_f32 conversion
// kernels launched in ggml_cuda_flash_attn_ext. The same value is used to
// ceil-divide the tensor's element count into the number of blocks.
#define CUDA_FA_CONVERT_BLOCK_SIZE 256

#define CUDA_Q8_0_NE_ALIGN 2048

Expand Down Expand Up @@ -12260,8 +12261,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
cudaMallocAsync((void **)&d_qkv, ggml_nelements(dst) * sizeof(half), main_stream);

// convert query to half
int num_blocks = (ggml_nelements(src0) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
flash_ext_f32_f16<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, main_stream>>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));
int num_blocks = (ggml_nelements(src0) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
flash_ext_f32_f16<<<num_blocks, CUDA_FA_CONVERT_BLOCK_SIZE, 0, main_stream>>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));

flash_attn_fwd(
d_query,
Expand All @@ -12272,8 +12273,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
cudaFreeAsync(d_softmax_lse, main_stream);

// convert output from f16 to f32
num_blocks = (ggml_nelements(dst) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
flash_ext_f16_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, main_stream>>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
num_blocks = (ggml_nelements(dst) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
flash_ext_f16_f32<<<num_blocks, CUDA_FA_CONVERT_BLOCK_SIZE, 0, main_stream>>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
cudaFreeAsync(d_query, main_stream);
cudaFreeAsync(d_qkv, main_stream);
return;
Expand Down

0 comments on commit 19775b0

Please sign in to comment.