From 19775b08ca99e22c036e7f81ef6dc041f9f5da8a Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 24 Mar 2024 13:07:40 -0600
Subject: [PATCH] fix block size conversion

---
 ggml-cuda.cu | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index dac3d60b724b9..7cbc12479dd30 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -625,6 +625,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
 #define CUDA_POOL2D_BLOCK_SIZE 256
+#define CUDA_FA_CONVERT_BLOCK_SIZE 256

 #define CUDA_Q8_0_NE_ALIGN 2048

@@ -12260,8 +12261,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
     cudaMallocAsync((void **)&d_qkv, ggml_nelements(dst) * sizeof(half), main_stream);

     // convert query to half
-    int num_blocks = (ggml_nelements(src0) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    flash_ext_f32_f16<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, main_stream>>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));
+    int num_blocks = (ggml_nelements(src0) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
+    flash_ext_f32_f16<<<num_blocks, CUDA_FA_CONVERT_BLOCK_SIZE, 0, main_stream>>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));

     flash_attn_fwd(
         d_query,
@@ -12272,8 +12273,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
     cudaFreeAsync(d_softmax_lse, main_stream);

     // convert output from f16 to f32
-    num_blocks = (ggml_nelements(dst) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    flash_ext_f16_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, main_stream>>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
+    num_blocks = (ggml_nelements(dst) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
+    flash_ext_f16_f32<<<num_blocks, CUDA_FA_CONVERT_BLOCK_SIZE, 0, main_stream>>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
     cudaFreeAsync(d_query, main_stream);
     cudaFreeAsync(d_qkv, main_stream);
     return;