From 888a724b53faa0e651f3856d741e40184eb41cd9 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 6 Mar 2024 00:23:36 -0500 Subject: [PATCH 01/12] cleanup FA implementation + flash decoding kernel (wip) --- ggml-cuda.cu | 864 ++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.c | 329 +++++++++++++++++++- ggml.h | 20 ++ llama.cpp | 80 ++++- 4 files changed, 1258 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 72bcec8cdb17a..3f5bfd0a17456 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -114,6 +114,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -720,7 +721,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -733,7 +733,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -743,18 +742,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// (void) x; -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -7080,6 +7079,676 @@ static __global__ void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment float16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id 
= threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D16 = D/16; + const int Q16 = Q/16; + const int C16 = C/16; + + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; + const int D2 = D/2; + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory +#pragma unroll + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + +#pragma unroll + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int j = 0; j < Q; ++j) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S = __float2half(0.0f); + half M[Q]; + + for (int i = 0; i < Q; ++i) { + M[i] = __float2half(-INFINITY); + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? 
(const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (int cc = 0; cc < C16; ++cc) { + half16x16_acc mqk[Q16]; + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + + if (mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half2 smax = make_half2(-INFINITY, -INFINITY); + + // online softmax + for (int j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + lane_id; + + const half2 s = ss2[j*T2 + p]; + + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); + } + + M[j] = warp_reduce_max(M[j]); + + // local sum + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); + + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + lane_id; + + const half2 s = ss2[j*T2 + p]; + + const half2 vs = h2exp(s - M2); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss2[j*T2 + p] = vs; + } + + ls = warp_reduce_sum(ls); + + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } + } + + smax = warp_reduce_max(smax); + + // skip -INF blocks + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { + continue; + } + + // O = diag(ms)*O + for (int j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); + } + } + + // restore zeros + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mv[D16]; + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); + } + + 
half16x16_a ms[Q16]; + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); + } + + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; + } + } + + // reduce the warps sequentially + for (int sg = 1; sg < num_warps; ++sg) { + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int j = lane_id; j < Q; j += NW) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + const half M = __hmax(M0, M1); + + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + +template +__global__ void flash_attn_row( + const half* query, + half* key /* reuse key buffer for partials result */, + const half* value, + const half* mask, + int kv_size, + float scale, + int reduce_block, + int head_stride) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int lane_index = threadIdx.x; + const int warp_index = threadIdx.y; + + const int warp_data_size = (head_dim*kv_tensor + 2); + + extern __shared__ char shmem[]; + half2* squery2 = (half2*)shmem; // load query buffer + half * squery = (half *)shmem; // probabilities buffer after online softmax + float* sscores = (float*)(shmem + head_dim*kv_tensor*sizeof(half)); // scores buffer after QK^T + float* warp_buffer = (float*)(shmem + head_dim*kv_tensor*sizeof(half) + (kv_block + 2)*sizeof(float) + (warp_index*warp_data_size*sizeof(float))); + const int HD2 = head_dim / 2; + + 
// load query with 128x2 shape (repeat row twice) + const half2* query_ = (const half2*)(query + head_dim*blockIdx.y); // shift as head +#pragma unroll + for (int j = 0; j < kv_tensor; j += num_warps) { + const int q_off = j + warp_index; + if (q_off >= kv_tensor) { + break; + } + +#pragma unroll + for (int i = 0; i < HD2; i += WARP_SIZE) { + const int h_offset = i + lane_index; + if (h_offset >= HD2) { + break; + } + squery2[q_off*HD2 + h_offset] = query_[h_offset]; + } + } + + __syncthreads(); + + { // QK^T + half16x16_a query_m; + nvcuda::wmma::load_matrix_sync(query_m, squery, 16); + half16x16_bT key_m; + float16x16_acc kq_m; + + const int kv_per_warp = kv_block / num_warps; + const int sum_diag = WMMA_K / kv_tensor; + // assert(kv_per_warp % kv_tensor == 0); + + const int kvi = warp_index*kv_per_warp; + +#pragma unroll + for (int kv = 0; kv < kv_per_warp; kv += kv_tensor) { + nvcuda::wmma::load_matrix_sync(key_m, key + head_stride*blockIdx.y + (blockIdx.x*kv_block + kvi + kv)*head_dim, 16); + nvcuda::wmma::fill_fragment(kq_m, 0.0f); + nvcuda::wmma::mma_sync(kq_m, query_m, key_m, kq_m); + nvcuda::wmma::store_matrix_sync(warp_buffer, kq_m, 16, nvcuda::wmma::mem_row_major); + + // sum diagonal + if (lane_index < kv_tensor) { + float seq = 0.0f; + const int seq_idx = kvi + kv + lane_index; +#pragma unroll + for (int d0 = 0; d0 < sum_diag; d0++) { + const int diag_idx = d0 + lane_index * sum_diag; + seq += warp_buffer[diag_idx*WMMA_M + diag_idx]; // sum diagonal + } + + // store sequence result + sscores[seq_idx] = seq*scale + __half2float(mask[blockIdx.x*kv_block + seq_idx]); // save as float for softmax + } + } + + __syncthreads(); + } + + // perform online softmax + { + const int kv_per_warp = kv_block / num_warps; + float M = -INFINITY; + + const int kvi = warp_index*kv_per_warp; + + for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { + M = fmaxf(M, fmaxf(sscores[kvi + kv], sscores[kvi + kv + 1])); + } + + M = warp_reduce_max(M); + + float S = 0.0f; + + for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { + S += expf(sscores[kvi + kv] - M); + S += expf(sscores[kvi + kv + 1] - M); + } + + S = warp_reduce_sum(S); + + if(lane_index == 0) { + warp_buffer[0] = M; + warp_buffer[1] = S; + // printf("warp index: %d, M= %.4f, S= %.4f\n", warp_index, M, S); + } + + __syncthreads(); + + // reduce warps + if(warp_index == 0 && lane_index == 0) { + float M0 = warp_buffer[0]; + float S0 = warp_buffer[1]; + + for(int w = 1; w < num_warps; w++) { + float M1 = warp_buffer[w * warp_data_size]; + float S1 = warp_buffer[w * warp_data_size + 1]; + + float M = fmaxf(M0, M1); + + float ms0 = expf(M0 - M); + float ms1 = expf(M1 - M); + + S0 = S0*ms0 + S1*ms1; + M0 = M; + } + + // printf("block M = %.4f, S= %.4f\n", M0, S0); + + // real softmax M and S for this block + sscores[kv_block] = M0; + sscores[kv_block + 1] = S0; + } + + __syncthreads(); + + // reuse shared memory padding + M = sscores[kv_block]; + + const int te_per_warp = tensor_elements / num_warps; + + const int si = warp_index*te_per_warp; + +#pragma unroll + for (int t0 = 0; t0 < te_per_warp; t0 += WARP_SIZE) { + const int tei = t0 + lane_index; + if(tei >= te_per_warp) { + break; + } + + const int sq_offset = si + tei; + squery[sq_offset] = __float2half(expf(sscores[sq_offset % kv_block] - M)); + } + + __syncthreads(); + } + + { // QK^TV + half16x16_a qk_m; + nvcuda::wmma::load_matrix_sync(qk_m, squery, 16); + half16x16_bT value_m; + float16x16_acc qkv_m; + + const int 
reduce_exccedent = reduce_block - gridDim.x; +#pragma unroll + for(int h0 = 0; h0 < head_dim; h0 += num_warps) { + const int hi = h0 + warp_index; + if(hi >= head_dim) { + break; + } + + const int output_offset = blockIdx.y * head_stride + hi * reduce_block; + + // `value` need to be transposed + nvcuda::wmma::load_matrix_sync(value_m, value + hi * kv_size + blockIdx.x*kv_block + blockIdx.y * head_stride, 16); + nvcuda::wmma::fill_fragment(qkv_m, 0.0f); + nvcuda::wmma::mma_sync(qkv_m, qk_m, value_m, qkv_m); + nvcuda::wmma::store_matrix_sync(warp_buffer, qkv_m, 16, nvcuda::wmma::mem_row_major); + + // sum diagonal + if (lane_index == 0) { + float hdim = 0.0f; + + for (int d = 0; d < WMMA_K; d++) { + hdim += warp_buffer[d*WMMA_M + d]; // sum diagonal + } + + // assume the key has been processed by blocks launched per head + key[output_offset + blockIdx.x] = __float2half(hdim); + key[blockIdx.y * head_stride + head_dim*reduce_block + blockIdx.x*2] = __float2half(sscores[kv_block]); // max of this kv block + key[blockIdx.y * head_stride + head_dim*reduce_block + blockIdx.x*2 + 1] = __float2half(sscores[kv_block + 1]); // sum of this kv block + + if(blockIdx.x == 0) { // just the first block will do this + for(int i = 0; i < reduce_exccedent; i ++) { + // this is a padding to perform a matrix multiplication without incorrect values + key[output_offset + gridDim.x + i] = __float2half(0.0f); + } + } + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + +template +__global__ void fa_reduce(const half* red_buf, float* qkv, int kv_size, int num_kv_blocks, int reduce_block) { + const int lane_index = threadIdx.x; + const int warp_index = threadIdx.y; + + const int head_offset = head_dim * kv_size * blockIdx.x; + + extern __shared__ char shmem[]; + half * sscale = (half *)shmem; + float* sf_lse = (float*)(shmem + tensor_elements*sizeof(half)); + + // make scale 1.0 diagonal + if(warp_index == 0 && lane_index == 0) { + const int softmax_lse_offset = head_offset + head_dim*reduce_block; + float M0 = __half2float(red_buf[softmax_lse_offset]); + float S0 = __half2float(red_buf[softmax_lse_offset + 1]); + + for(int i = 1; i < num_kv_blocks; i++) { + float M1 = __half2float(red_buf[softmax_lse_offset + i*2]); + float S1 = __half2float(red_buf[softmax_lse_offset + i*2 + 1]); + + float M = fmaxf(M0, M1); + + float ms0 = expf(M0 - M); + float ms1 = expf(M1 - M); + + S0 = S0*ms0 + S1*ms1; + M0 = M; + + sscale[i*2 ] = __float2half(ms0); + sscale[i*2 + 1] = __float2half(ms1); + } + + sf_lse[0] = S0; + } + + __syncthreads(); + + const int hd_per_warp = head_dim / num_warps; + + // reduce kv blocks (very slow!!) 
+ for(int hi = warp_index*hd_per_warp; hi < head_dim; hi += num_warps*hd_per_warp) { + for(int hdi = lane_index; hdi < hd_per_warp; hdi += WARP_SIZE) { + float hdim = __half2float(red_buf[head_offset + (hi + hdi) * reduce_block]); + for(int kv = 1; kv < num_kv_blocks; kv++) { + hdim = hdim*__half2float(sscale[kv*2]) + __half2float(red_buf[head_offset + (hi + hdi) * reduce_block + kv]) * __half2float(sscale[kv*2 + 1]); + } + qkv[blockIdx.x * head_dim + hi + lane_index] = hdim / sf_lse[0]; + } + } +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -11270,6 +11939,164 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(src2->type == GGML_TYPE_F16); + GGML_ASSERT(src3->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(src2->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(src3->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->backend == GGML_BACKEND_TYPE_GPU); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(ne01, 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; + ggml_tensor_extra_gpu * src3_extra = src3 ? (ggml_tensor_extra_gpu *) src3->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + GGML_ASSERT(NQPB <= 32); + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = ne01 <= nqpb ? std::max(2, std::min((int) ne11/ncpw, nwarps_max)) : 1; + + dim3 blocks_num((ne01 + nqpb - 1) / nqpb, ne02, ne03); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(ne00 + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + + switch (ne00) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? 
((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? 
src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + default: + break; + } + + CUDA_CHECK(cudaGetLastError()); +} + + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -11565,6 +12392,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -11579,7 +12408,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } @@ -12399,6 +13232,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/ggml.c b/ggml.c index 6a10bbcb45e45..b83b49361b1ce 100644 --- a/ggml.c +++ b/ggml.c @@ -885,7 +885,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -911,7 +911,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -1501,6 +1501,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1585,6 +1616,35 @@ inline 
static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -1839,6 +1899,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1863,7 +1924,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1927,6 +1988,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1951,7 +2013,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6041,6 +6103,59 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -14233,6 +14348,198 @@ static void ggml_compute_forward_flash_attn( } } + +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(ggml_fp16_t)); + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? 
GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + + ggml_vec_dot_f16(D, + &s, 0, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + // TODO: implement F32 precision + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -15816,6 +16123,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor); @@ -17579,6 +17890,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + n_tasks = n_threads; + } break; case GGML_OP_FLASH_FF: { n_tasks = n_threads; @@ -17974,6 +18289,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index 0ea4f8847795c..60a594431e927 100644 --- a/ggml.h +++ b/ggml.h @@ -470,6 +470,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1712,6 +1713,25 @@ extern "C" { struct ggml_tensor * v, bool masked); +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! 
+ GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index e9192b4fa60dc..4dd5d8b26e359 100644 --- a/llama.cpp +++ b/llama.cpp @@ -104,6 +104,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN // // logging @@ -4828,23 +4829,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! 
- ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -5005,6 +5017,34 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); + + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); @@ -5014,8 +5054,8 @@ static struct ggml_tensor * llm_build_kqv( ggml_mul_mat_set_prec(kq, GGML_PREC_F32); } -#if defined(GGML_USE_KOMPUTE) -#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") +#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) +#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, and Kompute") #pragma message(" Falling back to ggml_alibi(). 
Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") if (hparams.f_max_alibi_bias > 0.0f) { @@ -5037,7 +5077,7 @@ static struct ggml_tensor * llm_build_kqv( cb(kq, "kq_soft_max_ext", il); } - // split cached v into n_head heads + // split cached v into n_head heads (transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, @@ -5052,8 +5092,9 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif ggml_build_forward_expand(graph, cur); @@ -5299,7 +5340,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5478,7 +5519,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache @@ -8268,7 +8309,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min(cparams.n_ctx, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } @@ -12134,7 +12176,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -12352,6 +12397,9 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_cls, "inp_cls"); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); From 936cea0370ecf2627c524280906ec8e79c74ab1f Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 7 Mar 2024 12:22:15 -0500 Subject: [PATCH 02/12] fix NaNs when context reset --- ggml-cuda.cu | 7 ++++++- llama.cpp | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3f5bfd0a17456..2e3464b886235 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7317,6 +7317,11 @@ static __global__ void flash_attn_ext_f16( for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; + if(__hisinf(M[j]) == -1) { + ss2[j*T2 + p] = ls; + continue; + } + const half2 s = ss2[j*T2 + p]; const half2 vs = h2exp(s - M2); @@ -7332,7 +7337,7 @@ static __global__ void flash_attn_ext_f16( const half ms = hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { + if (lane_id == j && !__hisnan(ms)) { ss[j*T + C + j] = ms; S = S*ms + ls.x + ls.y; diff --git a/llama.cpp b/llama.cpp index 4dd5d8b26e359..de15a8281886b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8310,7 +8310,7 @@ static int llama_decode_internal( // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important // note: we pad the n_kv because certain GPU kernels require it (e.g. 
ggml_flash_attn_ext) - kv_self.n = std::min(cparams.n_ctx, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); + kv_self.n = std::min(cparams.n_ctx, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256u))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } From f49081216b402a8fc32e4ea929641a1d5f6fd266 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 7 Mar 2024 15:01:19 -0500 Subject: [PATCH 03/12] enable flash decoding --- ggml-cuda.cu | 281 ++++++++++++++++++++++++++++----------------------- llama.cpp | 8 +- 2 files changed, 163 insertions(+), 126 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2e3464b886235..6c1302d1fa32b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7496,7 +7496,7 @@ static __global__ void flash_attn_ext_f16( template __global__ void flash_attn_row( - const half* query, + const float* query, half* key /* reuse key buffer for partials result */, const half* value, const half* mask, @@ -7518,7 +7518,7 @@ __global__ void flash_attn_row( const int HD2 = head_dim / 2; // load query with 128x2 shape (repeat row twice) - const half2* query_ = (const half2*)(query + head_dim*blockIdx.y); // shift as head + const float2* query_ = (const float2*)(query + head_dim*blockIdx.y); // shift as head #pragma unroll for (int j = 0; j < kv_tensor; j += num_warps) { const int q_off = j + warp_index; @@ -7532,7 +7532,7 @@ __global__ void flash_attn_row( if (h_offset >= HD2) { break; } - squery2[q_off*HD2 + h_offset] = query_[h_offset]; + squery2[q_off*HD2 + h_offset] = __float22half2_rn(query_[h_offset]); } } @@ -7546,7 +7546,6 @@ __global__ void flash_attn_row( const int kv_per_warp = kv_block / num_warps; const int sum_diag = WMMA_K / kv_tensor; - // assert(kv_per_warp % kv_tensor == 0); const int kvi = warp_index*kv_per_warp; @@ -7590,9 +7589,11 @@ __global__ void flash_attn_row( float S = 0.0f; - for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { - S += expf(sscores[kvi + kv] - M); - S += expf(sscores[kvi + kv + 1] - M); + if(M != -INFINITY) { + for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { + S += expf(sscores[kvi + kv] - M); + S += expf(sscores[kvi + kv + 1] - M); + } } S = warp_reduce_sum(S); @@ -11979,123 +11980,155 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor #define NQPB 16 #define NCPW 128 - const int nqpb = NQPB; // queries per block - const int ncpw = NCPW; // cache values per warp (does not work for other values) - - GGML_ASSERT(NQPB <= 32); - - const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? - // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = ne01 <= nqpb ? std::max(2, std::min((int) ne11/ncpw, nwarps_max)) : 1; - - dim3 blocks_num((ne01 + nqpb - 1) / nqpb, ne02, ne03); - dim3 block_dim(32, nwarps, 1); - - const size_t shmem = nqpb*(ne00 + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - - // increase shared memory limit to 96KB - //const size_t shmem_max = 96*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); - - switch (ne00) { - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? 
((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - case 96: - flash_attn_ext_f16<96, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - case 112: - flash_attn_ext_f16<112, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - case 256: - flash_attn_ext_f16<256, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, - nb01, nb02, nb03, - nb11, nb12, nb13, - ne0, ne1, ne2, ne3); - break; - default: - break; + bool flash_decoding = true; + + if(!flash_decoding || ne00 != 128 || ne01 > 1) { + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + GGML_ASSERT(NQPB <= 32); + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. 
how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = ne01 <= nqpb ? std::max(2, std::min((int) ne11/ncpw, nwarps_max)) : 1; + + dim3 blocks_num((ne01 + nqpb - 1) / nqpb, ne02, ne03); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(ne00 + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + + switch (ne00) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? 
src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + src3 ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + src3 ? src3->ne[1] : 0, src3 ? src3->nb[1] : 0, + nb01, nb02, nb03, + nb11, nb12, nb13, + ne0, ne1, ne2, ne3); + break; + default: + break; + } + } else { +#define WMMA_M 16 +#define WMMA_N 16 +#define WMMA_K 16 +#define TENSOR_ELEMENTS 256 +#define KV_BLOCK_SIZE 256 + + constexpr int num_warps = 8; + constexpr int kv_per_block = KV_BLOCK_SIZE; + + // assert(kv_size % kv_per_block == 0); + dim3 grid_dim(ne11 / kv_per_block, ne02, 1); + dim3 block_dim(WARP_SIZE, num_warps, 1); + + int shmem = + ne00*2*sizeof(half) /* query buffer */ + + (kv_per_block + 2)*sizeof(float) /* scores buffer */ + + num_warps * (TENSOR_ELEMENTS + 2) * sizeof(float) /* tensor core result buffer per warp */; + + int reduce_block = ((grid_dim.x + WMMA_M - 1) / WMMA_M) * WMMA_N; + flash_attn_row<128, num_warps, 2, kv_per_block, TENSOR_ELEMENTS, WMMA_M, WMMA_N, WMMA_K><<>>( + (const float*)src0_extra->data_device[g_main_device], + (half*)src1_extra->data_device[g_main_device], + (const half*)src2_extra->data_device[g_main_device], + (const half*)src3_extra->data_device[g_main_device], ne11, scale, reduce_block, ne10*ne11); + fa_reduce<128, num_warps, TENSOR_ELEMENTS><<>>( + (const half*)src1_extra->data_device[g_main_device], + (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block); } CUDA_CHECK(cudaGetLastError()); diff --git a/llama.cpp b/llama.cpp index de15a8281886b..35dbd5cb75397 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5025,16 +5025,20 @@ static struct ggml_tensor * llm_build_kqv( GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + bool flash_decoding = true; + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], + ggml_view_3d(ctx, kv.v_l[il], n_embd_head_v, n_kv, n_head_kv, ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, ggml_cont(ctx, k), ggml_cont(ctx, + flash_decoding && n_tokens == 1 ? 
+ ggml_permute(ctx, v, 1, 0, 2, 3) : v), ggml_cont(ctx, kq_mask), kq_scale); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); From eecf7ee0812c268f62417d2a3d4e57ee2dd574d1 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 7 Mar 2024 17:00:44 -0500 Subject: [PATCH 04/12] optional flash-decoding --- ggml-cuda.cu | 11 ++++++----- llama.cpp | 17 +++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6c1302d1fa32b..969fd7d3438f1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,7 +1,7 @@ #include "ggml-cuda.h" #include "ggml.h" #include "ggml-backend-impl.h" - +#define GGML_FLASH_DECODING #include #include #include @@ -7339,7 +7339,6 @@ static __global__ void flash_attn_ext_f16( // create a QxQ diagonal matrix for rescaling the output if (lane_id == j && !__hisnan(ms)) { ss[j*T + C + j] = ms; - S = S*ms + ls.x + ls.y; } } @@ -11980,9 +11979,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor #define NQPB 16 #define NCPW 128 - bool flash_decoding = true; - - if(!flash_decoding || ne00 != 128 || ne01 > 1) { +#ifdef GGML_FLASH_DECODING + if(ne00 != 128 || ne01 > 1) { +#endif const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) @@ -12101,6 +12100,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor default: break; } +#ifdef GGML_FLASH_DECODING } else { #define WMMA_M 16 #define WMMA_N 16 @@ -12130,6 +12130,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor (const half*)src1_extra->data_device[g_main_device], (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block); } +#endif CUDA_CHECK(cudaGetLastError()); } diff --git a/llama.cpp b/llama.cpp index 35dbd5cb75397..9cfb624a7f452 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,6 +1,6 @@ #define LLAMA_API_INTERNAL #include "llama.h" - +#define GGML_FLASH_DECODING #include "unicode.h" #include "ggml.h" @@ -5025,8 +5025,6 @@ static struct ggml_tensor * llm_build_kqv( GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - bool flash_decoding = true; - // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5035,11 +5033,14 @@ static struct ggml_tensor * llm_build_kqv( ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - - cur = ggml_flash_attn_ext(ctx, q, ggml_cont(ctx, k), ggml_cont(ctx, - flash_decoding && n_tokens == 1 ? - ggml_permute(ctx, v, 1, 0, 2, 3) : v), ggml_cont(ctx, kq_mask), kq_scale); - +#ifdef GGML_FLASH_DECODING + cur = ggml_flash_attn_ext(ctx, q, + n_tokens == 1 ? ggml_cont(ctx, k) : k, + n_tokens == 1 ? 
ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3)) : v, + ggml_cont(ctx, kq_mask), kq_scale); +#else + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); +#endif ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); From be9ecd6f05fe2ad35ede967cc9075b4638d2be59 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 7 Mar 2024 21:53:53 -0500 Subject: [PATCH 05/12] fix bug + debug prints --- ggml-cuda.cu | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 969fd7d3438f1..d88cb992440ed 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7749,7 +7749,7 @@ __global__ void fa_reduce(const half* red_buf, float* qkv, int kv_size, int num_ for(int kv = 1; kv < num_kv_blocks; kv++) { hdim = hdim*__half2float(sscale[kv*2]) + __half2float(red_buf[head_offset + (hi + hdi) * reduce_block + kv]) * __half2float(sscale[kv*2 + 1]); } - qkv[blockIdx.x * head_dim + hi + lane_index] = hdim / sf_lse[0]; + qkv[blockIdx.x * head_dim + hi + hdi] = hdim / sf_lse[0]; } } } @@ -11944,6 +11944,44 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, const char* name, cudaStream_t stream) +{ + // FILE* f = fopen(filename, "wb"); + // int len = strlen(name); + // int n_dims = ggml_n_dims(tensor); + // printf("writing '%s' - %d dimens\n", name, n_dims); + // fwrite(&n_dims, sizeof(n_dims), 1, f); + printf("============== %s =================\n", name); + int ttype = (int)tensor->type; + // fwrite(&ttype, sizeof(ttype), 1, f); + // for (int i = 0; i < n_dims; ++i) { + // int ne_ = (int) tensor->ne[i]; + // fwrite(&ne_, sizeof(ne_), 1, f); + // } + // fwrite(&len, sizeof(len), 1, f); + // fwrite(name, len, 1, f); + void* data = malloc(ggml_nbytes(tensor)); + ggml_backend_tensor_get(tensor, data, 0, ggml_nbytes(tensor)); + // printf("[%d, %d] %zu\n", tensor->ne[0], tensor->ne[1], ggml_nbytes(tensor)); + for(int r = 0;r < (tensor->ne[1] > 1 ? 
10 : 1); r ++) { + for(int c = 0;c < 10; c ++) { + if(ttype == GGML_TYPE_F32) { + printf("%0.5ff, ",((float*)data)[r * tensor->ne[0] + c]); + } else if(ttype == GGML_TYPE_F16) { + printf("%0.5ff, ", __half2float(((half*)data)[r * tensor->ne[0] + c])); + } + } + printf("\n"); + } + // if(tensor->ne[0] == 128 && tensor->ne[1] == 32) { + // printf("BACKTRACKING: %.4f\n", ((float*)data)[113]); + // } + // fwrite(data, ggml_nbytes(tensor), 1, f); + free(data); + // fclose(f); +} +bool debug_kernel = true; + inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -12108,7 +12146,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor #define TENSOR_ELEMENTS 256 #define KV_BLOCK_SIZE 256 - constexpr int num_warps = 8; + constexpr int num_warps = 1; constexpr int kv_per_block = KV_BLOCK_SIZE; // assert(kv_size % kv_per_block == 0); @@ -12120,6 +12158,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor (kv_per_block + 2)*sizeof(float) /* scores buffer */ + num_warps * (TENSOR_ELEMENTS + 2) * sizeof(float) /* tensor core result buffer per warp */; + if(ne01 == 1 && debug_kernel) { + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data", main_stream); + } + int reduce_block = ((grid_dim.x + WMMA_M - 1) / WMMA_M) * WMMA_N; flash_attn_row<128, num_warps, 2, kv_per_block, TENSOR_ELEMENTS, WMMA_M, WMMA_N, WMMA_K><<>>( (const float*)src0_extra->data_device[g_main_device], @@ -12129,6 +12171,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor fa_reduce<128, num_warps, TENSOR_ELEMENTS><<>>( (const half*)src1_extra->data_device[g_main_device], (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block); + if(ne01 == 1 && debug_kernel) { + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data", main_stream); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data", main_stream); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data", main_stream); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data", main_stream); + debug_kernel = false; + } } #endif From 82374d0adbd8d3378479e59722288755587a6b5e Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 8 Mar 2024 12:29:21 -0500 Subject: [PATCH 06/12] add more debug mods --- ggml-cuda.cu | 39 +++++++++++++++++++++++++++++---------- ggml.c | 36 ++++++++++++++++++++++++++++++++++++ llama.cpp | 2 +- run-only-cpu.sh | 4 ++++ run-server.sh | 11 +++++++++++ 5 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 run-only-cpu.sh create mode 100644 run-server.sh diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d88cb992440ed..2ac00083a3b5a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -11944,14 +11944,14 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } -static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, const char* name, cudaStream_t stream) +static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, const char* name) { // FILE* f = fopen(filename, "wb"); // int len = strlen(name); // int n_dims = ggml_n_dims(tensor); // printf("writing '%s' - %d 
dimens\n", name, n_dims); // fwrite(&n_dims, sizeof(n_dims), 1, f); - printf("============== %s =================\n", name); + printf("============== %s [%d, %d, %d, %d] =================\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); int ttype = (int)tensor->type; // fwrite(&ttype, sizeof(ttype), 1, f); // for (int i = 0; i < n_dims; ++i) { @@ -11963,8 +11963,8 @@ static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, void* data = malloc(ggml_nbytes(tensor)); ggml_backend_tensor_get(tensor, data, 0, ggml_nbytes(tensor)); // printf("[%d, %d] %zu\n", tensor->ne[0], tensor->ne[1], ggml_nbytes(tensor)); - for(int r = 0;r < (tensor->ne[1] > 1 ? 10 : 1); r ++) { - for(int c = 0;c < 10; c ++) { + for(int r = 0;r < (tensor->ne[1] > 1 ? 16 : 1); r ++) { + for(int c = 0;c < 16; c ++) { if(ttype == GGML_TYPE_F32) { printf("%0.5ff, ",((float*)data)[r * tensor->ne[0] + c]); } else if(ttype == GGML_TYPE_F16) { @@ -11980,7 +11980,8 @@ static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, free(data); // fclose(f); } -bool debug_kernel = true; + +bool debug_kernel = true, debug_prompt = true; inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -12138,6 +12139,23 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor default: break; } + if(ne01 == 1 && debug_kernel) { + printf("TOKEN GENERATION\n"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); + debug_kernel = false; + } else if(ne01 == 112 && debug_prompt) { + printf("PROMPT PROCESSING\n"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); + debug_prompt = false; + } #ifdef GGML_FLASH_DECODING } else { #define WMMA_M 16 @@ -12159,7 +12177,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor num_warps * (TENSOR_ELEMENTS + 2) * sizeof(float) /* tensor core result buffer per warp */; if(ne01 == 1 && debug_kernel) { - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data", main_stream); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); } int reduce_block = ((grid_dim.x + WMMA_M - 1) / WMMA_M) * WMMA_N; @@ -12172,10 +12190,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor (const half*)src1_extra->data_device[g_main_device], (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block); if(ne01 == 1 && debug_kernel) { - 
save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data", main_stream); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data", main_stream); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data", main_stream); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data", main_stream); + printf("TOKEN GENERATION FLASH DECODING\n"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); + save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); debug_kernel = false; } } diff --git a/ggml.c b/ggml.c index b83b49361b1ce..350c512734cb7 100644 --- a/ggml.c +++ b/ggml.c @@ -14520,6 +14520,24 @@ static void ggml_compute_forward_flash_attn_ext_f16( } } +static void print_tensor(const char* filename, const struct ggml_tensor* tensor, const char* name) +{ + printf("============== %s [%d, %d, %d, %d] =================\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + int ttype = (int)tensor->type; + for(int r = 0;r < (tensor->ne[1] > 1 ? 16 : 1); r ++) { + for(int c = 0;c < 16; c ++) { + if(ttype == GGML_TYPE_F32) { + printf("%0.5ff, ",((float*)tensor->data)[r * tensor->ne[0] + c]); + } else if(ttype == GGML_TYPE_F16) { + printf("%0.5ff, ", ggml_fp16_to_fp32(((ggml_fp16_t*)tensor->data)[r * tensor->ne[0] + c])); + } + } + printf("\n"); + } +} + +bool debug_kernel_ = true, debug_prompt_ = true; + static void ggml_compute_forward_flash_attn_ext( const struct ggml_compute_params * params, const struct ggml_tensor * q, @@ -14538,6 +14556,24 @@ static void ggml_compute_forward_flash_attn_ext( GGML_ASSERT(false); } break; } + + if(q->ne[1] == 1 && debug_kernel_) { + printf("TOKEN GENERATION\n"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", q, "Query data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", k, "Key data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", v, "Value data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", mask, "Mask data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); + debug_kernel_ = false; + } else if(q->ne[1] == 112 && debug_prompt_) { + printf("PROMPT PROCESSING\n"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", q, "Query data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", k, "Key data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", v, "Value data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", mask, "Mask data"); + print_tensor("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); + debug_prompt_ = false; + } } // ggml_compute_forward_flash_ff diff --git a/llama.cpp b/llama.cpp index 9cfb624a7f452..2bec80573a915 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5033,7 +5033,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); -#ifdef GGML_FLASH_DECODING +#if defined(GGML_FLASH_DECODING) && defined(GGML_USE_CUBLAS) cur = ggml_flash_attn_ext(ctx, q, n_tokens == 1 ? 
ggml_cont(ctx, k) : k, n_tokens == 1 ? ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3)) : v, diff --git a/run-only-cpu.sh b/run-only-cpu.sh new file mode 100644 index 0000000000000..316874ee37770 --- /dev/null +++ b/run-only-cpu.sh @@ -0,0 +1,4 @@ +#!/bin/bash +cmake -S .. -B ../build -DLLAMA_CUBLAS=OFF +cmake --build ../build --config Release +../build/bin/server -m mixtral-instruct-8x7b-q2k.gguf --port 8081 -t 1 -c 512 --host 172.17.0.2 --log-format text diff --git a/run-server.sh b/run-server.sh new file mode 100644 index 0000000000000..968974cc37079 --- /dev/null +++ b/run-server.sh @@ -0,0 +1,11 @@ +#!/bin/bash +cd .. +apt-get -y install libssl-dev +wget https://github.com/Kitware/CMake/releases/download/v3.29.0-rc1/cmake-3.29.0-rc1.tar.gz +tar -xzvf cmake-3.29.0-rc1.tar.gz +cd cmake-3.29.0-rc1 && ./bootstrap && make -j$(nproc) && make install +cd ../llama.cpp/models +wget https://huggingface.co/ikawrakow/mixtral-instruct-8x7b-quantized-gguf/resolve/main/mixtral-instruct-8x7b-q2k.gguf +mkdir ../build && cmake -S .. -B ../build -DLLAMA_CUBLAS=ON +cmake --build ../build --config Release +../build/bin/server -m mixtral-instruct-8x7b-q2k.gguf --port 8081 -c 512 --host 172.17.0.2 -ngl 32 --log-format text From 7b979d1091a8f80474a4ca96d785ea6c03e7d7e1 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 8 Mar 2024 19:22:04 -0500 Subject: [PATCH 07/12] fix mixtral models - flash decoding --- ggml-cuda.cu | 103 +++++++++++++++++++++++++------------------------- run-server.sh | 2 +- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2ac00083a3b5a..b42f143df8845 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7496,13 +7496,14 @@ static __global__ void flash_attn_ext_f16( template __global__ void flash_attn_row( const float* query, - half* key /* reuse key buffer for partials result */, + const half* key /* reuse key buffer for partials result */, const half* value, const half* mask, + half* tmp, int kv_size, float scale, - int reduce_block, - int head_stride) { + int head_stride, + int r_kv_heads) { #if __CUDA_ARCH__ >= CC_VOLTA const int lane_index = threadIdx.x; const int warp_index = threadIdx.y; @@ -7515,6 +7516,7 @@ __global__ void flash_attn_row( float* sscores = (float*)(shmem + head_dim*kv_tensor*sizeof(half)); // scores buffer after QK^T float* warp_buffer = (float*)(shmem + head_dim*kv_tensor*sizeof(half) + (kv_block + 2)*sizeof(float) + (warp_index*warp_data_size*sizeof(float))); const int HD2 = head_dim / 2; + const int kv_head_offset = (blockIdx.y / r_kv_heads) * head_stride; // load query with 128x2 shape (repeat row twice) const float2* query_ = (const float2*)(query + head_dim*blockIdx.y); // shift as head @@ -7545,12 +7547,11 @@ __global__ void flash_attn_row( const int kv_per_warp = kv_block / num_warps; const int sum_diag = WMMA_K / kv_tensor; - const int kvi = warp_index*kv_per_warp; #pragma unroll for (int kv = 0; kv < kv_per_warp; kv += kv_tensor) { - nvcuda::wmma::load_matrix_sync(key_m, key + head_stride*blockIdx.y + (blockIdx.x*kv_block + kvi + kv)*head_dim, 16); + nvcuda::wmma::load_matrix_sync(key_m, key + (blockIdx.x*kv_block + kvi + kv)*head_dim + kv_head_offset, 16); nvcuda::wmma::fill_fragment(kq_m, 0.0f); nvcuda::wmma::mma_sync(kq_m, query_m, key_m, kq_m); nvcuda::wmma::store_matrix_sync(warp_buffer, kq_m, 16, nvcuda::wmma::mem_row_major); @@ -7564,7 +7565,6 @@ __global__ void flash_attn_row( const int diag_idx = d0 + lane_index * sum_diag; seq += warp_buffer[diag_idx*WMMA_M + diag_idx]; // sum diagonal } - // store 
sequence result sscores[seq_idx] = seq*scale + __half2float(mask[blockIdx.x*kv_block + seq_idx]); // save as float for softmax } @@ -7623,8 +7623,6 @@ __global__ void flash_attn_row( M0 = M; } - // printf("block M = %.4f, S= %.4f\n", M0, S0); - // real softmax M and S for this block sscores[kv_block] = M0; sscores[kv_block + 1] = S0; @@ -7659,18 +7657,23 @@ __global__ void flash_attn_row( half16x16_bT value_m; float16x16_acc qkv_m; - const int reduce_exccedent = reduce_block - gridDim.x; + // const int qkv_block_size = gridDim.x * head_dim + gridDim.x * 2; + // const int qkv_head_offset = kv_head_offset + (blockIdx.y % r_kv_heads) * qkv_block_size; + + const int qkv_block_size = gridDim.x * head_dim + gridDim.x * 2; + const int qkv_head_offset = blockIdx.y * qkv_block_size; + #pragma unroll - for(int h0 = 0; h0 < head_dim; h0 += num_warps) { + for(int h0 = 0; h0 < head_dim; h0 += num_warps) { const int hi = h0 + warp_index; if(hi >= head_dim) { break; } - const int output_offset = blockIdx.y * head_stride + hi * reduce_block; + const int output_offset = qkv_head_offset + hi * gridDim.x; // `value` need to be transposed - nvcuda::wmma::load_matrix_sync(value_m, value + hi * kv_size + blockIdx.x*kv_block + blockIdx.y * head_stride, 16); + nvcuda::wmma::load_matrix_sync(value_m, value + hi * kv_size + blockIdx.x*kv_block + kv_head_offset, 16); nvcuda::wmma::fill_fragment(qkv_m, 0.0f); nvcuda::wmma::mma_sync(qkv_m, qk_m, value_m, qkv_m); nvcuda::wmma::store_matrix_sync(warp_buffer, qkv_m, 16, nvcuda::wmma::mem_row_major); @@ -7683,45 +7686,38 @@ __global__ void flash_attn_row( hdim += warp_buffer[d*WMMA_M + d]; // sum diagonal } - // assume the key has been processed by blocks launched per head - key[output_offset + blockIdx.x] = __float2half(hdim); - key[blockIdx.y * head_stride + head_dim*reduce_block + blockIdx.x*2] = __float2half(sscores[kv_block]); // max of this kv block - key[blockIdx.y * head_stride + head_dim*reduce_block + blockIdx.x*2 + 1] = __float2half(sscores[kv_block + 1]); // sum of this kv block - - if(blockIdx.x == 0) { // just the first block will do this - for(int i = 0; i < reduce_exccedent; i ++) { - // this is a padding to perform a matrix multiplication without incorrect values - key[output_offset + gridDim.x + i] = __float2half(0.0f); - } - } + tmp[output_offset + blockIdx.x] = __float2half(hdim); } } + + if(warp_index == 0 && lane_index == 0) { + tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2] = __float2half(sscores[kv_block]); // max of this kv block + tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2 + 1] = __float2half(sscores[kv_block + 1]); // sum of this kv block + } } #else NO_DEVICE_CODE; #endif } -template -__global__ void fa_reduce(const half* red_buf, float* qkv, int kv_size, int num_kv_blocks, int reduce_block) { +template +__global__ void fa_reduce(const half* partial_qkv, float* qkv, int kv_size, int num_kv_blocks, int r_kv_heads) { const int lane_index = threadIdx.x; const int warp_index = threadIdx.y; - const int head_offset = head_dim * kv_size * blockIdx.x; + const int qkv_partial_offset = blockIdx.x * (num_kv_blocks * head_dim + num_kv_blocks*2); extern __shared__ char shmem[]; - half * sscale = (half *)shmem; - float* sf_lse = (float*)(shmem + tensor_elements*sizeof(half)); + float* softmax_lse = (float *)shmem; - // make scale 1.0 diagonal if(warp_index == 0 && lane_index == 0) { - const int softmax_lse_offset = head_offset + head_dim*reduce_block; - float M0 = __half2float(red_buf[softmax_lse_offset]); - float S0 = 
__half2float(red_buf[softmax_lse_offset + 1]); + const int softmax_lse_offset = qkv_partial_offset + num_kv_blocks * head_dim; + float M0 = __half2float(partial_qkv[softmax_lse_offset]); + float S0 = __half2float(partial_qkv[softmax_lse_offset + 1]); for(int i = 1; i < num_kv_blocks; i++) { - float M1 = __half2float(red_buf[softmax_lse_offset + i*2]); - float S1 = __half2float(red_buf[softmax_lse_offset + i*2 + 1]); + float M1 = __half2float(partial_qkv[softmax_lse_offset + i*2]); + float S1 = __half2float(partial_qkv[softmax_lse_offset + i*2 + 1]); float M = fmaxf(M0, M1); @@ -7731,11 +7727,11 @@ __global__ void fa_reduce(const half* red_buf, float* qkv, int kv_size, int num_ S0 = S0*ms0 + S1*ms1; M0 = M; - sscale[i*2 ] = __float2half(ms0); - sscale[i*2 + 1] = __float2half(ms1); + softmax_lse[i*2 ] = ms0; + softmax_lse[i*2 + 1] = ms1; } - sf_lse[0] = S0; + softmax_lse[0] = S0; } __syncthreads(); @@ -7745,11 +7741,13 @@ __global__ void fa_reduce(const half* red_buf, float* qkv, int kv_size, int num_ // reduce kv blocks (very slow!!) for(int hi = warp_index*hd_per_warp; hi < head_dim; hi += num_warps*hd_per_warp) { for(int hdi = lane_index; hdi < hd_per_warp; hdi += WARP_SIZE) { - float hdim = __half2float(red_buf[head_offset + (hi + hdi) * reduce_block]); + const int hdim_index = hi + hdi; + const int qkv_index = qkv_partial_offset + hdim_index * num_kv_blocks; + float hdim = __half2float(partial_qkv[qkv_index]); for(int kv = 1; kv < num_kv_blocks; kv++) { - hdim = hdim*__half2float(sscale[kv*2]) + __half2float(red_buf[head_offset + (hi + hdi) * reduce_block + kv]) * __half2float(sscale[kv*2 + 1]); + hdim = hdim * softmax_lse[kv*2] + __half2float(partial_qkv[qkv_index + kv]) * softmax_lse[kv * 2 + 1]; } - qkv[blockIdx.x * head_dim + hi + hdi] = hdim / sf_lse[0]; + qkv[blockIdx.x * head_dim + hdim_index] = hdim / softmax_lse[0]; } } } @@ -11963,8 +11961,8 @@ static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, void* data = malloc(ggml_nbytes(tensor)); ggml_backend_tensor_get(tensor, data, 0, ggml_nbytes(tensor)); // printf("[%d, %d] %zu\n", tensor->ne[0], tensor->ne[1], ggml_nbytes(tensor)); - for(int r = 0;r < (tensor->ne[1] > 1 ? 16 : 1); r ++) { - for(int c = 0;c < 16; c ++) { + for(int r = 0;r < (tensor->ne[1] > 1 ? 
6 : 1); r ++) { + for(int c = 0;c < 8; c ++) { if(ttype == GGML_TYPE_F32) { printf("%0.5ff, ",((float*)data)[r * tensor->ne[0] + c]); } else if(ttype == GGML_TYPE_F16) { @@ -12147,7 +12145,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); debug_kernel = false; - } else if(ne01 == 112 && debug_prompt) { + } else if(ne01 == 104 && debug_prompt) { printf("PROMPT PROCESSING\n"); save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); @@ -12166,11 +12164,16 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor constexpr int num_warps = 1; constexpr int kv_per_block = KV_BLOCK_SIZE; + int num_kv_blocks = ne11 / kv_per_block; - // assert(kv_size % kv_per_block == 0); - dim3 grid_dim(ne11 / kv_per_block, ne02, 1); + dim3 grid_dim(num_kv_blocks, ne02, 1); dim3 block_dim(WARP_SIZE, num_warps, 1); + half* tmp; + cudaMalloc((void **)&tmp, ((num_kv_blocks * ne00) + num_kv_blocks*2) * ne02 * sizeof(half)); + + const int r_kv_heads = ne02 / ne12; + int shmem = ne00*2*sizeof(half) /* query buffer */ + (kv_per_block + 2)*sizeof(float) /* scores buffer */ + @@ -12180,15 +12183,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); } - int reduce_block = ((grid_dim.x + WMMA_M - 1) / WMMA_M) * WMMA_N; flash_attn_row<128, num_warps, 2, kv_per_block, TENSOR_ELEMENTS, WMMA_M, WMMA_N, WMMA_K><<>>( (const float*)src0_extra->data_device[g_main_device], - (half*)src1_extra->data_device[g_main_device], - (const half*)src2_extra->data_device[g_main_device], - (const half*)src3_extra->data_device[g_main_device], ne11, scale, reduce_block, ne10*ne11); - fa_reduce<128, num_warps, TENSOR_ELEMENTS><<>>( (const half*)src1_extra->data_device[g_main_device], - (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block); + (const half*)src2_extra->data_device[g_main_device], + (const half*)src3_extra->data_device[g_main_device], + tmp, ne11, scale, ne10*ne11, r_kv_heads); + fa_reduce<128, num_warps><<>>(tmp, (float *)dst_extra->data_device[g_main_device], ne11, num_kv_blocks, r_kv_heads); if(ne01 == 1 && debug_kernel) { printf("TOKEN GENERATION FLASH DECODING\n"); save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); diff --git a/run-server.sh b/run-server.sh index 968974cc37079..57f134336a46c 100644 --- a/run-server.sh +++ b/run-server.sh @@ -8,4 +8,4 @@ cd ../llama.cpp/models wget https://huggingface.co/ikawrakow/mixtral-instruct-8x7b-quantized-gguf/resolve/main/mixtral-instruct-8x7b-q2k.gguf mkdir ../build && cmake -S .. 
-B ../build -DLLAMA_CUBLAS=ON cmake --build ../build --config Release -../build/bin/server -m mixtral-instruct-8x7b-q2k.gguf --port 8081 -c 512 --host 172.17.0.2 -ngl 32 --log-format text +../build/bin/server -m mixtral-instruct-8x7b-q2k.gguf --port 8081 -c 512 --host 172.17.0.5 -ngl 32 --log-format text From 653b2575ad3ef45d6f7aeb4b1de6cf28bb1eb5f6 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 8 Mar 2024 20:21:36 -0500 Subject: [PATCH 08/12] update test --- run-server.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run-server.sh b/run-server.sh index 57f134336a46c..927540b17c6c7 100644 --- a/run-server.sh +++ b/run-server.sh @@ -5,7 +5,7 @@ wget https://github.com/Kitware/CMake/releases/download/v3.29.0-rc1/cmake-3.29.0 tar -xzvf cmake-3.29.0-rc1.tar.gz cd cmake-3.29.0-rc1 && ./bootstrap && make -j$(nproc) && make install cd ../llama.cpp/models -wget https://huggingface.co/ikawrakow/mixtral-instruct-8x7b-quantized-gguf/resolve/main/mixtral-instruct-8x7b-q2k.gguf +wget https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q6_K.gguf mkdir ../build && cmake -S .. -B ../build -DLLAMA_CUBLAS=ON cmake --build ../build --config Release -../build/bin/server -m mixtral-instruct-8x7b-q2k.gguf --port 8081 -c 512 --host 172.17.0.5 -ngl 32 --log-format text +../build/bin/server -m mixtral-8x7b-v0.1.Q6_K.gguf --port 8081 -c 2048 --host 172.17.0.2 -ngl 33 --log-format text From 9d3b57ee26dcc0984de227e468b9bd5ae6870469 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 14 Mar 2024 11:58:47 -0600 Subject: [PATCH 09/12] flash decoding load tile data in sram --- ggml-cuda.cu | 475 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 339 insertions(+), 136 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b42f143df8845..d7960c1ccdb89 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -115,6 +115,8 @@ #include #include #include +#include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -7493,28 +7495,29 @@ static __global__ void flash_attn_ext_f16( #endif } -template +template __global__ void flash_attn_row( - const float* query, - const half* key /* reuse key buffer for partials result */, - const half* value, - const half* mask, - half* tmp, + const float* __restrict__ query, + const half * __restrict__ key /* reuse key buffer for partials result */, + const half * __restrict__ value, + const half * __restrict__ mask, + half * __restrict__ tmp, int kv_size, float scale, int head_stride, - int r_kv_heads) { + int r_kv_heads, + int warp_data_size, + int qkv_block_size) { #if __CUDA_ARCH__ >= CC_VOLTA const int lane_index = threadIdx.x; const int warp_index = threadIdx.y; - const int warp_data_size = (head_dim*kv_tensor + 2); - extern __shared__ char shmem[]; half2* squery2 = (half2*)shmem; // load query buffer half * squery = (half *)shmem; // probabilities buffer after online softmax - float* sscores = (float*)(shmem + head_dim*kv_tensor*sizeof(half)); // scores buffer after QK^T - float* warp_buffer = (float*)(shmem + head_dim*kv_tensor*sizeof(half) + (kv_block + 2)*sizeof(float) + (warp_index*warp_data_size*sizeof(float))); + half * softmax_lse = (half *)(shmem + head_dim*kv_tensor*sizeof(half)); + half * warp_buffer = (half *)(shmem + head_dim*kv_tensor*sizeof(half) + 2*sizeof(half) + (warp_index*warp_data_size*sizeof(half))); + const int HD2 = head_dim / 2; const int kv_head_offset = (blockIdx.y / r_kv_heads) * 
head_stride; @@ -7539,15 +7542,19 @@ __global__ void flash_attn_row( __syncthreads(); + const int kv_per_warp = kv_block / num_warps; + const int KPW2 = kv_per_warp/2; + const int kvi = warp_index*kv_per_warp; + const int KVI2 = kvi/2; + { // QK^T half16x16_a query_m; nvcuda::wmma::load_matrix_sync(query_m, squery, 16); half16x16_bT key_m; - float16x16_acc kq_m; + half16x16_acc kq_m; + half scale_ = __float2half(scale); - const int kv_per_warp = kv_block / num_warps; const int sum_diag = WMMA_K / kv_tensor; - const int kvi = warp_index*kv_per_warp; #pragma unroll for (int kv = 0; kv < kv_per_warp; kv += kv_tensor) { @@ -7558,15 +7565,16 @@ __global__ void flash_attn_row( // sum diagonal if (lane_index < kv_tensor) { - float seq = 0.0f; + half seq = __half2float(0.0f); const int seq_idx = kvi + kv + lane_index; #pragma unroll for (int d0 = 0; d0 < sum_diag; d0++) { const int diag_idx = d0 + lane_index * sum_diag; seq += warp_buffer[diag_idx*WMMA_M + diag_idx]; // sum diagonal } + // store sequence result - sscores[seq_idx] = seq*scale + __half2float(mask[blockIdx.x*kv_block + seq_idx]); // save as float for softmax + squery[seq_idx] = seq*scale_ + mask[blockIdx.x*kv_block + seq_idx]; // save as float for softmax } } @@ -7575,23 +7583,26 @@ __global__ void flash_attn_row( // perform online softmax { - const int kv_per_warp = kv_block / num_warps; - float M = -INFINITY; + half M = __float2half(-INFINITY); const int kvi = warp_index*kv_per_warp; - for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { - M = fmaxf(M, fmaxf(sscores[kvi + kv], sscores[kvi + kv + 1])); +#pragma unroll + for (int k0 = 0; k0 < kv_per_warp; k0 += WARP_SIZE) { + const int kv = k0 + lane_index; + if(kv >= kv_per_warp) { + break; + } + M = __hmax(M, squery[kvi + kv]); } M = warp_reduce_max(M); + half2 M2 = make_half2(M, M); + half2 S = make_half2(0.0, 0.0); - float S = 0.0f; - - if(M != -INFINITY) { - for (int kv = lane_index*kv_tensor; kv < kv_per_warp; kv += WARP_SIZE*kv_tensor) { - S += expf(sscores[kvi + kv] - M); - S += expf(sscores[kvi + kv + 1] - M); + if(__hisinf(M) != -1) { + for (int kv = lane_index; kv < KPW2; kv += WARP_SIZE) { + S += h2exp(squery2[KVI2 + kv] - M2); } } @@ -7599,53 +7610,47 @@ __global__ void flash_attn_row( if(lane_index == 0) { warp_buffer[0] = M; - warp_buffer[1] = S; - // printf("warp index: %d, M= %.4f, S= %.4f\n", warp_index, M, S); + warp_buffer[1] = S.x + S.y; } __syncthreads(); // reduce warps if(warp_index == 0 && lane_index == 0) { - float M0 = warp_buffer[0]; - float S0 = warp_buffer[1]; + half M0 = warp_buffer[0]; + half S0 = warp_buffer[1]; for(int w = 1; w < num_warps; w++) { - float M1 = warp_buffer[w * warp_data_size]; - float S1 = warp_buffer[w * warp_data_size + 1]; + half M1 = warp_buffer[w * warp_data_size]; + half S1 = warp_buffer[w * warp_data_size + 1]; - float M = fmaxf(M0, M1); + half M_ = __hmax(M0, M1); - float ms0 = expf(M0 - M); - float ms1 = expf(M1 - M); + half ms0 = hexp(M0 - M_); + half ms1 = hexp(M1 - M_); S0 = S0*ms0 + S1*ms1; - M0 = M; + M0 = M_; } // real softmax M and S for this block - sscores[kv_block] = M0; - sscores[kv_block + 1] = S0; + softmax_lse[0] = M0; + softmax_lse[1] = S0; } __syncthreads(); - // reuse shared memory padding - M = sscores[kv_block]; - - const int te_per_warp = tensor_elements / num_warps; - - const int si = warp_index*te_per_warp; + M = softmax_lse[0]; + M2 = make_half2(M, M); #pragma unroll - for (int t0 = 0; t0 < te_per_warp; t0 += WARP_SIZE) { - const int tei = t0 + lane_index; - if(tei >= 
te_per_warp) { + for (int k0 = 0; k0 < KPW2; k0 += WARP_SIZE) { + const int kv = k0 + lane_index; + if(kv >= KPW2) { break; } - - const int sq_offset = si + tei; - squery[sq_offset] = __float2half(expf(sscores[sq_offset % kv_block] - M)); + const int sq_offset = KVI2 + kv; + squery2[sq_offset] = h2exp(squery2[KVI2 + kv] - M2); } __syncthreads(); @@ -7655,16 +7660,12 @@ __global__ void flash_attn_row( half16x16_a qk_m; nvcuda::wmma::load_matrix_sync(qk_m, squery, 16); half16x16_bT value_m; - float16x16_acc qkv_m; - - // const int qkv_block_size = gridDim.x * head_dim + gridDim.x * 2; - // const int qkv_head_offset = kv_head_offset + (blockIdx.y % r_kv_heads) * qkv_block_size; + half16x16_acc qkv_m; - const int qkv_block_size = gridDim.x * head_dim + gridDim.x * 2; const int qkv_head_offset = blockIdx.y * qkv_block_size; #pragma unroll - for(int h0 = 0; h0 < head_dim; h0 += num_warps) { + for(int h0 = 0; h0 < head_dim; h0 += num_warps) { const int hi = h0 + warp_index; if(hi >= head_dim) { break; @@ -7680,19 +7681,238 @@ __global__ void flash_attn_row( // sum diagonal if (lane_index == 0) { - float hdim = 0.0f; + half hdim = __float2half(0.0f); for (int d = 0; d < WMMA_K; d++) { hdim += warp_buffer[d*WMMA_M + d]; // sum diagonal } - tmp[output_offset + blockIdx.x] = __float2half(hdim); + // assume the key has been processed by blocks launched per head + tmp[output_offset + blockIdx.x] = hdim; + } + } + + if(warp_index == 0 && lane_index == 0) { + tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2] = softmax_lse[0]; // max of this kv block + tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2 + 1] = softmax_lse[1]; // sum of this kv block + } + } +#else + NO_DEVICE_CODE; +#endif +} + + +// too much instructions and barriers +template // kv_block should be tensor elements +__global__ void flash_attn_row_fast( + const half* __restrict__ query, + const half* __restrict__ key, + const half* __restrict__ value, + const half* __restrict__ mask, + half* __restrict__ qkv_tmp, + const float scale, const int kv_size, const int qkv_partial_size, int r_kv_heads, int kv_stride, int head_stride) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int lane_index = threadIdx.x; + const int warp_index = threadIdx.y; + + cooperative_groups::thread_block blk = cooperative_groups::this_thread_block(); + + extern __shared__ char shmem[]; + half * sh_query = (half *)shmem; // Query half[kv_block] + half2* sh_query2 = (half2*)shmem; + half * warp_buffer = (half *)(shmem + kv_block*sizeof(half) + warp_index*kv_block*sizeof(half)); // warp_buffer float[256] + half * sh_kv_buffer = (half *)(shmem + kv_block*sizeof(half) + num_warps*kv_block*sizeof(half)); // Key half[kv_block][head_dim] | Value half[head_dim][kv_block] + half * sh_softmax = (half *)(shmem + kv_block*sizeof(half) + num_warps*kv_block*sizeof(half) + head_dim*kv_block*sizeof(half)); // softmax half[num_warps*2] + + // load query in shared memory + const int query_per_tensor = kv_block / head_dim; + const int kv_head_index = blockIdx.y / r_kv_heads; + const int HD2 = head_dim / 2; + +#pragma unroll + for (int qo = 0; qo < query_per_tensor; qo++) { + if(qo == 0) { // first read from global memory + cooperative_groups::memcpy_async(blk, sh_query, query + blockIdx.y*head_dim, sizeof(half) * head_dim); + cooperative_groups::wait(blk); + } else { // copy from shared memory + for(int i = (threadIdx.y*WARP_SIZE + threadIdx.x); i < HD2; i += num_warps*WARP_SIZE) { + sh_query2[qo*HD2 + i] = sh_query2[i]; + } + } + } + __syncthreads(); + + // load key in shared 
memory + { + for (int kv = 0; kv < kv_block; kv += query_per_tensor) { + for (int qo = 0; qo < query_per_tensor; qo++) { + const int key_index = blockIdx.x*kv_block + kv + qo; + cooperative_groups::memcpy_async(blk, + sh_kv_buffer + key_index*head_dim, + key + key_index*kv_stride + kv_head_index*head_stride, sizeof(half) * head_dim); + } + } + cooperative_groups::wait(blk); + } + + const int kv_per_warp = kv_block / num_warps; + const int kvi = warp_index*kv_per_warp; + const int KPW2 = kv_per_warp/2; + const int KVI2 = kvi/2; + + // perform QK^T*scale + mask and get max for softmax + { + half16x16_a qm; + nvcuda::wmma::load_matrix_sync(qm, sh_query, 16); + half16x16_bT km; + half16x16_acc kqm; + half M = __float2half(-INFINITE); + half scale_ = __float2half(scale); + + const int num_diag = WMMA_K / query_per_tensor; + // half* warp_tmp_buffer = sh_kv_buffer + kvi*head_dim; // save results from tensor cores + + for (int kv = 0; kv < kv_per_warp; kv += query_per_tensor) { + nvcuda::wmma::load_matrix_sync(km, sh_kv_buffer + (kvi + kv)*head_dim, 16); + nvcuda::wmma::fill_fragment(kqm, 0.0f); + nvcuda::wmma::mma_sync(kqm, qm, km, kqm); + nvcuda::wmma::store_matrix_sync(warp_buffer, kqm, 16, nvcuda::wmma::mem_row_major); + + // sum diagonal + if (lane_index < query_per_tensor) { + // TODO: make this half type + half seq = __half2float(0.0f); + const int seq_idx = kvi + kv + lane_index; +#pragma unroll + for (int d0 = 0; d0 < num_diag; d0++) { + const int diag_idx = d0 + lane_index * num_diag; + seq += warp_buffer[diag_idx*WMMA_M + diag_idx]; // sum diagonal + } + + seq = seq*scale_ + mask[blockIdx.x*kv_block + seq_idx]; + + // store sequence result + sh_query[seq_idx] = seq; // save as float for softmax + M = __hmax(M, seq); + } + } + + M = warp_reduce_max(M); + if(lane_index == 0) { + sh_softmax[warp_index*2] = M; + } + } + __syncthreads(); + + { + half2 S = make_half2(0.0, 0.0); + half M = sh_softmax[warp_index*2]; + + if(__hisinf(M) != -1) { + half2 M2 = make_half2(M, M); + for (int kv = lane_index; kv < KPW2; kv += WARP_SIZE) { + S += h2exp(sh_query2[KVI2 + kv] - M2); + } + } + + S = warp_reduce_sum(S); + + if(lane_index == 0) { + sh_softmax[warp_index*2 + 1] = S.x + S.y; + } + __syncthreads(); + } + + if(warp_index == 0 && lane_index == 0) { + half M0 = sh_softmax[0]; + half S0 = sh_softmax[1]; + + for(int w = 1; w < num_warps; w++) { + half M1 = sh_softmax[w*2]; + half S1 = sh_softmax[w*2 + 1]; + + half M = __hmax(M0, M1); + + half ms0 = hexp(M0 - M); + half ms1 = hexp(M1 - M); + + S0 = S0*ms0 + S1*ms1; + M0 = M; + } + + // real softmax M and S for this block + sh_softmax[0] = M0; + sh_softmax[1] = S0; + } + __syncthreads(); + + { + half M = sh_softmax[0]; + half2 M2 = make_half2(M, M); +#pragma unroll + for (int k0 = 0; k0 < KPW2; k0 += WARP_SIZE) { + const int kv = k0 + lane_index; + if(kv >= kv_per_warp) { + break; + } + const int kv_offset = KVI2 + kv; + sh_query2[kv_offset] = h2exp(sh_query2[kv_offset] - M2); + } + __syncthreads(); + } + + // load values in shared memory (no contiguous) (no coalesing acceses!!) 
+ // for (int kv = warp_index; kv < kv_block; kv += num_warps) { + // const int kv_offset = (blockIdx.x*kv_block + kv)*kv_stride + kv_head_index*head_stride; + // for (int hdim = lane_index; hdim < head_dim; hdim += WARP_SIZE) { + // sh_kv_buffer[hdim*kv_block + kv] = value[kv_offset + hdim]; + // } + // } + + // coalesing shared and global access (requires value transposed and contigous) + for (int hdim = 0; hdim < head_dim; hdim ++) { + cooperative_groups::memcpy_async(blk, + sh_kv_buffer + hdim*kv_block, + value + (blockIdx.x*kv_block + hdim * kv_size + kv_head_index*kv_size*head_dim), sizeof(half) * kv_block); + } + cooperative_groups::wait(blk); + + // perform softmax(QK^T)V + { + half16x16_a kqm; + nvcuda::wmma::load_matrix_sync(kqm, sh_query, 16); + half16x16_bT vm; + half16x16_acc qkvm; + + const int qkv_head_offset = blockIdx.y * qkv_partial_size; + +#pragma unroll + for(int h0 = 0; h0 < head_dim; h0 += num_warps) { + const int hi = h0 + warp_index; + if(hi >= head_dim) { + break; + } + + const int output_offset = qkv_head_offset + hi * gridDim.x; + + nvcuda::wmma::load_matrix_sync(vm, sh_kv_buffer + hi * kv_block, 16); + nvcuda::wmma::fill_fragment(qkvm, 0.0f); + nvcuda::wmma::mma_sync(qkvm, kqm, vm, qkvm); + nvcuda::wmma::store_matrix_sync(warp_buffer, qkvm, 16, nvcuda::wmma::mem_row_major); + + if (lane_index == 0) { + half hdim = __float2half(0.0f); + for (int d = 0; d < WMMA_K; d++) { + hdim += warp_buffer[d*WMMA_M + d]; // sum diagonal + } + qkv_tmp[output_offset + blockIdx.x] = hdim; } } if(warp_index == 0 && lane_index == 0) { - tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2] = __float2half(sscores[kv_block]); // max of this kv block - tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2 + 1] = __float2half(sscores[kv_block + 1]); // sum of this kv block + qkv_tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2] = sh_softmax[0]; // max of this kv block + qkv_tmp[qkv_head_offset + gridDim.x * head_dim + blockIdx.x*2 + 1] = sh_softmax[1]; // sum of this kv block } } #else @@ -7702,54 +7922,52 @@ __global__ void flash_attn_row( template __global__ void fa_reduce(const half* partial_qkv, float* qkv, int kv_size, int num_kv_blocks, int r_kv_heads) { +#if __CUDA_ARCH__ >= CC_VOLTA const int lane_index = threadIdx.x; const int warp_index = threadIdx.y; const int qkv_partial_offset = blockIdx.x * (num_kv_blocks * head_dim + num_kv_blocks*2); extern __shared__ char shmem[]; - float* softmax_lse = (float *)shmem; + half* softmax_lse = (half *)shmem; if(warp_index == 0 && lane_index == 0) { const int softmax_lse_offset = qkv_partial_offset + num_kv_blocks * head_dim; - float M0 = __half2float(partial_qkv[softmax_lse_offset]); - float S0 = __half2float(partial_qkv[softmax_lse_offset + 1]); + half M0 = partial_qkv[softmax_lse_offset]; + half S0 = partial_qkv[softmax_lse_offset + 1]; for(int i = 1; i < num_kv_blocks; i++) { - float M1 = __half2float(partial_qkv[softmax_lse_offset + i*2]); - float S1 = __half2float(partial_qkv[softmax_lse_offset + i*2 + 1]); - - float M = fmaxf(M0, M1); - - float ms0 = expf(M0 - M); - float ms1 = expf(M1 - M); - + half M1 = partial_qkv[softmax_lse_offset + i*2]; + half S1 = partial_qkv[softmax_lse_offset + i*2 + 1]; + half M = __hmax(M0, M1); + half ms0 = hexp(M0 - M); + half ms1 = hexp(M1 - M); S0 = S0*ms0 + S1*ms1; M0 = M; - softmax_lse[i*2 ] = ms0; softmax_lse[i*2 + 1] = ms1; } softmax_lse[0] = S0; } - __syncthreads(); const int hd_per_warp = head_dim / num_warps; - // reduce kv blocks (very slow!!) 
for(int hi = warp_index*hd_per_warp; hi < head_dim; hi += num_warps*hd_per_warp) { for(int hdi = lane_index; hdi < hd_per_warp; hdi += WARP_SIZE) { const int hdim_index = hi + hdi; const int qkv_index = qkv_partial_offset + hdim_index * num_kv_blocks; - float hdim = __half2float(partial_qkv[qkv_index]); + half hdim = partial_qkv[qkv_index]; for(int kv = 1; kv < num_kv_blocks; kv++) { - hdim = hdim * softmax_lse[kv*2] + __half2float(partial_qkv[qkv_index + kv]) * softmax_lse[kv * 2 + 1]; + hdim = hdim * softmax_lse[kv*2] +partial_qkv[qkv_index + kv] * softmax_lse[kv * 2 + 1]; } - qkv[blockIdx.x * head_dim + hdim_index] = hdim / softmax_lse[0]; + qkv[blockIdx.x * head_dim + hdim_index] = __half2float(hdim / softmax_lse[0]); } } +#else + NO_DEVICE_CODE; +#endif } template @@ -11944,20 +12162,20 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, const char* name) { - // FILE* f = fopen(filename, "wb"); - // int len = strlen(name); - // int n_dims = ggml_n_dims(tensor); - // printf("writing '%s' - %d dimens\n", name, n_dims); - // fwrite(&n_dims, sizeof(n_dims), 1, f); + FILE* f = fopen(filename, "wb"); + int len = strlen(name); + int n_dims = ggml_n_dims(tensor); + printf("writing '%s' - %d dimens\n", name, n_dims); + fwrite(&n_dims, sizeof(n_dims), 1, f); printf("============== %s [%d, %d, %d, %d] =================\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); int ttype = (int)tensor->type; - // fwrite(&ttype, sizeof(ttype), 1, f); - // for (int i = 0; i < n_dims; ++i) { - // int ne_ = (int) tensor->ne[i]; - // fwrite(&ne_, sizeof(ne_), 1, f); - // } - // fwrite(&len, sizeof(len), 1, f); - // fwrite(name, len, 1, f); + fwrite(&ttype, sizeof(ttype), 1, f); + for (int i = 0; i < n_dims; ++i) { + int ne_ = (int) tensor->ne[i]; + fwrite(&ne_, sizeof(ne_), 1, f); + } + fwrite(&len, sizeof(len), 1, f); + fwrite(name, len, 1, f); void* data = malloc(ggml_nbytes(tensor)); ggml_backend_tensor_get(tensor, data, 0, ggml_nbytes(tensor)); // printf("[%d, %d] %zu\n", tensor->ne[0], tensor->ne[1], ggml_nbytes(tensor)); @@ -11971,12 +12189,9 @@ static void save_tensor_to_file(const char* filename, const ggml_tensor* tensor, } printf("\n"); } - // if(tensor->ne[0] == 128 && tensor->ne[1] == 32) { - // printf("BACKTRACKING: %.4f\n", ((float*)data)[113]); - // } - // fwrite(data, ggml_nbytes(tensor), 1, f); + fwrite(data, ggml_nbytes(tensor), 1, f); free(data); - // fclose(f); + fclose(f); } bool debug_kernel = true, debug_prompt = true; @@ -12137,32 +12352,14 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor default: break; } - if(ne01 == 1 && debug_kernel) { - printf("TOKEN GENERATION\n"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); - debug_kernel = false; - } else if(ne01 == 104 && debug_prompt) { - printf("PROMPT PROCESSING\n"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); - 
save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); - debug_prompt = false; - } #ifdef GGML_FLASH_DECODING } else { #define WMMA_M 16 #define WMMA_N 16 #define WMMA_K 16 -#define TENSOR_ELEMENTS 256 #define KV_BLOCK_SIZE 256 - constexpr int num_warps = 1; + constexpr int num_warps = 8; constexpr int kv_per_block = KV_BLOCK_SIZE; int num_kv_blocks = ne11 / kv_per_block; @@ -12174,29 +12371,35 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor const int r_kv_heads = ne02 / ne12; - int shmem = - ne00*2*sizeof(half) /* query buffer */ + - (kv_per_block + 2)*sizeof(float) /* scores buffer */ + - num_warps * (TENSOR_ELEMENTS + 2) * sizeof(float) /* tensor core result buffer per warp */; - - if(ne01 == 1 && debug_kernel) { - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-k-256.tensor", src1, "Key data"); - } - - flash_attn_row<128, num_warps, 2, kv_per_block, TENSOR_ELEMENTS, WMMA_M, WMMA_N, WMMA_K><<>>( - (const float*)src0_extra->data_device[g_main_device], - (const half*)src1_extra->data_device[g_main_device], - (const half*)src2_extra->data_device[g_main_device], - (const half*)src3_extra->data_device[g_main_device], - tmp, ne11, scale, ne10*ne11, r_kv_heads); - fa_reduce<128, num_warps><<>>(tmp, (float *)dst_extra->data_device[g_main_device], ne11, num_kv_blocks, r_kv_heads); - if(ne01 == 1 && debug_kernel) { - printf("TOKEN GENERATION FLASH DECODING\n"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-q-256.tensor", src0, "Query data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-v-256.tensor", src2, "Value data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-mask-256.tensor", src3, "Mask data"); - save_tensor_to_file("C:\\proyectos\\kernel-data\\tg\\fa-cuda-qkv-256.tensor", dst, "QKV data"); - debug_kernel = false; + bool flash_decoding_sram = false; + + if(!flash_decoding_sram) { + int shmem = + ne00*2*sizeof(half) /* query buffer */ + + 2*sizeof(half) /* scores buffer */ + + num_warps * (KV_BLOCK_SIZE + 2) * sizeof(half) /* tensor core result buffer per warp */; + flash_attn_row<128, num_warps, 2, kv_per_block, WMMA_M, WMMA_N, WMMA_K><<>>( + (const float*)src0_extra->data_device[g_main_device], + (const half *)src1_extra->data_device[g_main_device], + (const half *)src2_extra->data_device[g_main_device], + (const half *)src3_extra->data_device[g_main_device], + tmp, ne11, scale, ne10*ne11, r_kv_heads, KV_BLOCK_SIZE + 2, (num_kv_blocks * ne00) + num_kv_blocks*2); + fa_reduce<128, num_warps><<>>(tmp, (float *)dst_extra->data_device[g_main_device], ne11, num_kv_blocks, r_kv_heads); + } else { + int shmem = KV_BLOCK_SIZE*sizeof(half) + // query + kv_per_block*ne00*sizeof(half) + // kv size + num_warps*2*sizeof(half) + // softmax + num_warps*KV_BLOCK_SIZE*sizeof(half); // query + int shmem_red = num_kv_blocks*2*sizeof(float); + cudaFuncSetAttribute(flash_attn_row_fast<128, num_warps, kv_per_block, WMMA_M, WMMA_N, WMMA_K>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem); + + // requires key not contigous, value should be contigous + + // flash_attn_row_fast<128, num_warps, kv_per_block, WMMA_M, WMMA_N, WMMA_K><<>>( + // src0_f16_alloc.get(), (const half 
*)src1_extra->data_device[g_main_device], + // (const half *)src2_extra->data_device[g_main_device], + // (const half *)src3_extra->data_device[g_main_device], tmp, ne11, scale, ne10*ne11, r_kv_heads, nb11/sizeof(half), nb12/sizeof(half)); + // fa_reduce<128, num_warps><<>>(tmp, (float *)dst_extra->data_device[g_main_device], ne11, num_kv_blocks, r_kv_heads); } } #endif From 7c28e0f2129e0966549e3a815dd0bd426dc5b305 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 23 Mar 2024 11:46:40 -0600 Subject: [PATCH 10/12] flash original --- CMakeLists.txt | 12 +++++++-- examples/main/CMakeLists.txt | 6 +++++ ggml-cuda.cu | 48 +++++++++++++++++++++++++++++++++++- llama.cpp | 2 +- 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 48880f7204bf5..37f2f19bcc4a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target project("llama.cpp" C CXX) include(CheckIncludeFileCXX) +set(FLASH_DIR ../flash-attention-cpp) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) @@ -351,6 +353,7 @@ if (LLAMA_CUBLAS) find_package(CUDAToolkit) if (CUDAToolkit_FOUND) + link_directories(${FLASH_DIR}/build/Release) message(STATUS "cuBLAS found") enable_language(CUDA) @@ -358,6 +361,11 @@ if (LLAMA_CUBLAS) set(GGML_HEADERS_CUDA ggml-cuda.h) set(GGML_SOURCES_CUDA ggml-cuda.cu) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} + ${FLASH_DIR} + ${FLASH_DIR}/fa + ${FLASH_DIR}/cutlass/include) + add_compile_definitions(GGML_USE_CUBLAS) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) @@ -379,12 +387,12 @@ if (LLAMA_CUBLAS) if (LLAMA_STATIC) if (WIN32) # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt flash_attn) else () set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) endif() else() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt flash_attn) endif() set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index d532980b76da8..16655fe468a83 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} main.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +if (WIN32) + add_custom_command(TARGET main POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_SOURCE_DIR}/${FLASH_DIR}/build/Release/flash_attn.dll $) +else() + add_custom_command(TARGET main POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_SOURCE_DIR}/${FLASH_DIR}/build/libflash_attn.so $) +endif() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d7960c1ccdb89..dac3d60b724b9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,7 +1,7 @@ #include "ggml-cuda.h" #include "ggml.h" #include "ggml-backend-impl.h" -#define GGML_FLASH_DECODING +//#define GGML_FLASH_DECODING #include #include #include @@ -117,6 +117,7 @@ #include #include #include +#include "fa_api.h" #if CUDART_VERSION < 11020 #define 

From 0a13dfec530cc19fe9173e74cbe6d748facba270 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sat, 23 Mar 2024 18:51:14 -0600
Subject: [PATCH 11/12] fix linux library dir

---
 CMakeLists.txt | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 37f2f19bcc4a2..08b8dca721079 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,7 +353,12 @@ if (LLAMA_CUBLAS)
     find_package(CUDAToolkit)
 
     if (CUDAToolkit_FOUND)
-        link_directories(${FLASH_DIR}/build/Release)
+        if (WIN32)
+            link_directories(${FLASH_DIR}/build/Release)
+        else()
+            link_directories(${FLASH_DIR}/build)
+        endif()
+
         message(STATUS "cuBLAS found")
 
         enable_language(CUDA)

From 19775b08ca99e22c036e7f81ef6dc041f9f5da8a Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 24 Mar 2024 13:07:40 -0600
Subject: [PATCH 12/12] fix block size conversion

---
 ggml-cuda.cu | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index dac3d60b724b9..7cbc12479dd30 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -625,6 +625,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
 #define CUDA_POOL2D_BLOCK_SIZE 256
+#define CUDA_FA_CONVERT_BLOCK_SIZE 256
 #define CUDA_Q8_0_NE_ALIGN 2048
 
@@ -12260,8 +12261,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
         cudaMallocAsync((void **)&d_qkv, ggml_nelements(dst) * sizeof(half), main_stream);
 
         // convert query to half
-        int num_blocks = (ggml_nelements(src0) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-        flash_ext_f32_f16<<>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));
+        int num_blocks = (ggml_nelements(src0) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
+        flash_ext_f32_f16<<>>((float*)src0_extra->data_device[g_main_device], d_query, ggml_nelements(src0));
 
         flash_attn_fwd(
             d_query,
@@ -12272,8 +12273,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
         cudaFreeAsync(d_softmax_lse, main_stream);
 
         // convert output from f16 to f32
-        num_blocks = (ggml_nelements(dst) + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-        flash_ext_f16_f32<<>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
+        num_blocks = (ggml_nelements(dst) + CUDA_FA_CONVERT_BLOCK_SIZE - 1) / CUDA_FA_CONVERT_BLOCK_SIZE;
+        flash_ext_f16_f32<<>>(d_qkv, (float*)dst_extra->data_device[g_main_device], ggml_nelements(dst));
         cudaFreeAsync(d_query, main_stream);
         cudaFreeAsync(d_qkv, main_stream);
         return;
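
The final patch switches the grid computation for the two conversion kernels from CUDA_CPY_BLOCK_SIZE to a dedicated CUDA_FA_CONVERT_BLOCK_SIZE (the launch configurations themselves are elided in this dump). The point of sizing the grid and the launch from one constant is that the kernels' if(idx >= n) guard only tolerates too many threads, never too few. A small host-side illustration of that invariant, with hypothetical names and a 512 value chosen only to show the mismatch:

#include <cassert>

// Hypothetical stand-in for CUDA_FA_CONVERT_BLOCK_SIZE.
#define FA_CONVERT_BLOCK_SIZE 256

// Smallest grid that covers n elements at FA_CONVERT_BLOCK_SIZE threads per block.
static inline int fa_convert_grid_size(int n) {
    return (n + FA_CONVERT_BLOCK_SIZE - 1) / FA_CONVERT_BLOCK_SIZE;
}

int main() {
    // 1000 elements -> 4 blocks * 256 threads = 1024 threads; the 24 surplus
    // threads are discarded by the kernel's bounds check.
    assert(fa_convert_grid_size(1000) == 4);

    // If the grid were sized with some other constant (say 512) while the launch
    // still used 256 threads per block, 2 * 256 = 512 threads would cover only
    // half of the 1000 elements -- a silent correctness bug.
    const int mismatched_grid = (1000 + 512 - 1) / 512;   // = 2
    assert(mismatched_grid * FA_CONVERT_BLOCK_SIZE < 1000);
    return 0;
}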