diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 628f2dcbcd60b5..52c9204cc5910f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_SOFT_MAX_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_QUANTIZE_BLOCK_SIZE 256 @@ -4719,11 +4720,12 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int // the CUDA soft max implementation differs from the CPU implementation // instead of doubles floats are used -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { - const int rowx = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void soft_max_f32_warp(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - const int block_size = blockDim.y; - const int tid = threadIdx.y; + + const int block_size = blockDim.x; float max_val = -INFINITY; @@ -4763,6 +4765,66 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds } } +// use shared memory to reduce the number of global memory reads +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension + + const int block_size = blockDim.x; + + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE]; + + buf[tid] = -INFINITY; + + for (int col = tid; col < ncols; col += block_size) { + const int ix = rowx*ncols + col; + const int iy = 
rowy*ncols + col; + buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f)); + } + + __syncthreads(); + + // find the max value in the block + for (int i = block_size/2; i > 0; i >>= 1) { + if (tid < i) { + buf[tid] = max(buf[tid], buf[tid + i]); + } + __syncthreads(); + } + + float tmp = 0.f; + + for (int col = tid; col < ncols; col += block_size) { + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]); + tmp += val; + dst[ix] = val; + } + + __syncthreads(); + + buf[tid] = tmp; + + __syncthreads(); + + // sum up partial sums + for (int i = block_size/2; i > 0; i >>= 1) { + if (tid < i) { + buf[tid] += buf[tid + i]; + } + __syncthreads(); + } + + const float inv_tmp = 1.f / buf[0]; + + for (int col = tid; col < ncols; col += block_size) { + const int i = rowx*ncols + col; + dst[i] *= inv_tmp; + } +} + static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -5796,7 +5858,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols } static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { - const dim3 block_dims(1, WARP_SIZE, 1); + int nth = WARP_SIZE; + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth , 1, 1); const dim3 block_nums(nrows_x, 1, 1); soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale); }