Merge pull request #3 from ggerganov/flash-attn-cuda
cuda : fix flash_attn kernel to produce same results as CPU
FSSRepo authored Feb 1, 2024
2 parents fd878f7 + ac26f27 commit 43f7156
Showing 4 changed files with 129 additions and 86 deletions.
194 changes: 113 additions & 81 deletions ggml-cuda.cu
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix

half16x16_acc zr;
half16x16_acc lo[Q16][D16];
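
For context, a minimal sketch of the fragment types the new code relies on, assuming the half16x16_* aliases are thin typedefs over nvcuda::wmma 16x16x16 fragments (the real definitions live earlier in ggml-cuda.cu and are not shown in this diff): zr is a zero accumulator used later to restore the scratch diagonal area, and lo[Q16][D16] holds the warp's running output block O.

    #include <mma.h>
    #include <cuda_fp16.h>

    // Assumed aliases, for illustration only:
    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;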

// load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
}
}

nvcuda::wmma::fill_fragment(zr, 0.0);

// zero out lo
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
@@ -6487,12 +6491,12 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();

{
float S[Q];
float M[Q];
half S[Q];
half M[Q];

for(int i = 0; i < Q; i++) {
S[i] = 0.0f;
M[i] = -INFINITY;
S[i] = __float2half(0.0f);
M[i] = __float2half(-INFINITY);
}

// assume K and V are same shape
@@ -6526,11 +6530,16 @@ static __global__ void flash_attn_ext_f16(
}
}

const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

// prepare diagonal scale matrix
half16x16_b mscale;
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
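
Writing scale on the diagonal of an otherwise-zero 16x16 tile turns the scale-and-add-mask step below into a single tensor-core mma: for any 16x16 tile A, A·(scale·I) + M = scale·A + M. A standalone sketch of the idea with hypothetical names (tile, mask and diag are 16x16 row-major half tiles in shared memory; as in the kernel, every lane writes the same diagonal values, which is redundant but harmless):

    // Illustrative only: compute tile = scale*tile + mask with one mma_sync by
    // using D = scale*I as the matrix_b operand.
    __device__ void scale_add_mask_16x16(half * tile, const half * mask, half * diag, float scale) {
        using namespace nvcuda;

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> d;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

        // build D = scale * I in shared memory
        for (int i = 0; i < 16*16; ++i) diag[i]      = __float2half(0.0f);
        for (int i = 0; i < 16;    ++i) diag[i*16+i] = __float2half(scale);
        __syncwarp(); // make the writes visible to the whole warp before the cooperative load

        wmma::load_matrix_sync(a,   tile, 16);
        wmma::load_matrix_sync(d,   diag, 16);
        wmma::load_matrix_sync(acc, mask, 16, wmma::mem_row_major);

        // acc = a*d + acc == scale*tile + mask
        wmma::mma_sync(acc, a, d, acc);
        wmma::store_matrix_sync(tile, acc, 16, wmma::mem_row_major);
    }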

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,111 +6564,129 @@ static __global__ void flash_attn_ext_f16(

// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
// TODO: process mask
mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
}
half16x16_a mqka;
half16x16_acc mm;

// convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);

nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
}
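
An accumulator fragment cannot be fed back into mma_sync as an operand, and its element-to-lane mapping is unspecified, which is presumably why the old per-element loop could apply the scale but had to leave the mask as a TODO. The new code therefore bounces each mqk[j] tile through shared memory, reloads it as a matrix_a fragment, and issues mqk = mqka*mscale + mask in one mma_sync. The round-trip pattern in isolation (hypothetical helper; ld is the leading dimension of the shared-memory scratch):

    // Reuse an accumulator's values as a matrix_a operand by bouncing the tile
    // through shared memory (illustrative sketch).
    __device__ void acc_to_matrix_a(
            nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> & a,
            const nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> & acc,
            half * scratch, int ld) {
        nvcuda::wmma::store_matrix_sync(scratch, acc, ld, nvcuda::wmma::mem_row_major);
        nvcuda::wmma::load_matrix_sync (a, scratch, ld);
    }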

// used to detect blocks full of -INF
float smax = -INFINITY;
half smax = __float2half(-INFINITY);

// online softmax
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = lane_id;

const float m = M[j];
const float s = __half2float(ss[j*T + p]);
const half m = M[j];
const half s = ss[j*T + p];

smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
smax = warp_reduce_max(__hmax(smax, s));
M[j] = warp_reduce_max(__hmax(M[j], s));

const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

S[j] = S[j]*ms + warp_reduce_sum(vs);

// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = __float2half(ms);
ss[j*T + C + j] = ms;
}

// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
ss[j*T + p] = vs;
}
} else {
for (int64_t j = 0; j < Q; ++j) {
const float m = M[j];
const half m = M[j];

for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
const half s = ss[j*T + p];

smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
smax = __hmax(smax, s);
M[j] = __hmax(M[j], s);
}

const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
smax = warp_reduce_max(smax);
M[j] = warp_reduce_max(M[j]);

S[j] = S[j]*ms;
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = __float2half(ms);
ss[j*T + C + j] = ms;
}

// local sum
half ls = 0.0f;

for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
const half s = ss[j*T + p];

const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

S[j] = S[j] + warp_reduce_sum(vs);
ls += vs;

// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
ss[j*T + p] = vs;
}

S[j] = S[j]*ms + warp_reduce_sum(ls);
}
}
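
Both branches implement the same online-softmax update, now entirely in half precision: for each incoming score s the running maximum becomes M' = max(M, s), the running denominator is rescaled by exp(M - M'), and exp(s - M') is added. The C == 32 branch assigns exactly one score per lane; the general branch strides over C in steps of the warp size NW and accumulates a local sum ls before a single warp_reduce_sum. A scalar sketch of the per-score recurrence, assuming the __hmax/__hisinf/hexp intrinsics from cuda_fp16.h:

    #include <cuda_fp16.h>

    // One online-softmax step for a single query row and a single score s,
    // in half precision (illustrative; mirrors the update above).
    __device__ void online_softmax_step(half & S, half & M, half s) {
        const half m = M;                 // previous running maximum

        M = __hmax(M, s);                 // new running maximum

        // -inf means "nothing seen yet" (or a fully masked score)
        const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M);
        const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M);

        S = S*ms + vs;                    // rescale old denominator, add new term
    }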

// skip -INF blocks
if (smax == -INFINITY) {
if (__hisinf(smax)) {
continue;
}

// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
// half16x16_a mm;
// half16x16_b zro;
half16x16_a mm;
half16x16_b lob;

// nvcuda::wmma::fill_fragment(zro, 0.0);
// nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

for (int64_t i = 0; i < D16; ++i) {
//nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
for (uint32_t k = 0; k < 16*16; k++) {
half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
lo[j][i].x[k] = tmp * lo[j][i].x[k];
}
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);

nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
}

// restore zeros
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}
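
The commented-out per-element loop becomes a proper tensor-core rescale: each 16x16 block of O is bounced through shared memory so it can act as a matrix_b operand, multiplied by the diag(ms) tile that was just written at ss + 16*j*T + C + 16*j, and the scratch area is finally overwritten with the zero fragment zr so later iterations still see zeros off the diagonal. The per-block operation in isolation (hypothetical helper; mm already holds diag(ms)):

    // O_block <- diag(ms) * O_block, using a shared-memory bounce to turn the
    // accumulator into a matrix_b operand (illustrative sketch).
    __device__ void rescale_o_block(
            nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> & o,
            const nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> & mm,
            half * scratch, int ld) {
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> ob;

        nvcuda::wmma::store_matrix_sync(scratch, o, ld, nvcuda::wmma::mem_row_major);
        nvcuda::wmma::load_matrix_sync (ob, scratch, ld);

        nvcuda::wmma::fill_fragment(o, 0.0);
        nvcuda::wmma::mma_sync(o, mm, ob, o);   // o = diag(ms)*O + 0
    }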

// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

half16x16_b mk[D16];
for (int64_t i = 0; i < D16; ++i) {
half16x16_b mk;
nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
}

for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mv;
nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
half16x16_a mv[Q16];
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
}

nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
}
}
}
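
The O += P·V accumulation is also restructured: all D16 fragments of the V tile (mk[i]) and all Q16 fragments of the P tile (mv[j]) are loaded once per 16-column chunk of the KV cache and reused across the full Q16 x D16 grid of mma_sync calls, instead of being reloaded inside the inner loops as before. The reuse pattern in outline (hypothetical standalone sketch; ldp and ldv are the leading dimensions of the P and V tiles):

    // lo[j][i] += P_block(j) * V_block(i), loading each operand fragment once.
    template<int Q16, int D16>
    __device__ void accumulate_pv(
            nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> (&lo)[Q16][D16],
            const half * p, int ldp,      // P: Q16*16 rows x 16 cols, row-major
            const half * v, int ldv) {    // V: 16 rows x D16*16 cols, row-major
        using namespace nvcuda;

        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> mk[D16];
        for (int i = 0; i < D16; ++i) {
            wmma::load_matrix_sync(mk[i], v + i*16, ldv);
        }

        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> mv[Q16];
        for (int j = 0; j < Q16; ++j) {
            wmma::load_matrix_sync(mv[j], p + j*16*ldp, ldp);
        }

        for (int j = 0; j < Q16; ++j) {
            for (int i = 0; i < D16; ++i) {
                wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
            }
        }
    }

Keeping the Q16 + D16 operand fragments resident for the whole chunk trades some registers for far fewer shared-memory loads.
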
@@ -6669,16 +6696,16 @@ static __global__ void flash_attn_ext_f16(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
if (lane_id == 0) {
ss[j*T + 0] = __float2half(S[j]);
ss[j*T + 1] = __float2half(M[j]);
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
}
}

// reduce the warps sequentially
for (int64_t sg = 1; sg < num_warps; ++sg) {
float S = 0.0f;
float M = -INFINITY;
half S = __float2half(0.0f);
half M = __float2half(-INFINITY);

__syncthreads();

@@ -6696,25 +6723,25 @@ static __global__ void flash_attn_ext_f16(
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int64_t j = 0; j < Q; ++j) {
const float S0 = __half2float(ss[j*T + 0]);
const float S1 = __half2float(ss[j*T + sg*SH + 0]);
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];

const float M0 = __half2float(ss[j*T + 1]);
const float M1 = __half2float(ss[j*T + sg*SH + 1]);
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + sg*SH + 1];

M = max(M0, M1);
M = __hmax(M0, M1);

const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

S = S0*ms0 + S1*ms1;

if (lane_id == 0) {
ss[j*T + 0] = __float2half(S);
ss[j*T + 1] = __float2half(M);
ss[j*T + 0] = S;
ss[j*T + 1] = M;

ss[j*T + C + j ] = __float2half(ms0);
ss[j*T + C + j + sg*SH] = __float2half(ms1);
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + sg*SH] = ms1;
}
}
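
Each warp leaves its own partial (S, M) pair per query row in its slice of ss; warp 0 then folds them in one at a time with the usual log-sum-exp merge, now in half precision: M = max(M0, M1), S = S0·exp(M0 - M) + S1·exp(M1 - M), and the two exp factors are written back as diagonal rescale matrices for the corresponding O blocks. A scalar sketch of the merge:

    #include <cuda_fp16.h>

    // Merge two partial softmax states for one query row, as done by warp 0
    // above; ms0/ms1 are the factors later applied to the two partial O blocks.
    __device__ void merge_softmax_states(half S0, half M0, half S1, half M1,
                                         half & S, half & M, half & ms0, half & ms1) {
        M = __hmax(M0, M1);

        ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
        ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

        S = S0*ms0 + S1*ms1;
    }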

@@ -6732,10 +6759,11 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
// store temporally 'lo' data
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
// load 'lo' data into t
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);

// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);

nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
@@ -6751,15 +6779,13 @@ static __global__ void flash_attn_ext_f16(
}
}

// float2 * dst2 = (float2 *) dst;

// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const float S = __half2float(ss[j*T + 0]);
const half S = ss[j*T + 0];

for (int64_t i = lane_id; i < D; i += NW) {
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}
}
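
The final store now keeps the 1/S normalization in half precision and only converts the quotient to float, whereas the old code converted first and divided in float. The difference is small but keeps the arithmetic on the same half-precision path as the rest of the kernel; side by side (illustrative helpers):

    #include <cuda_fp16.h>

    // Old vs. new final normalization of one output element.
    __device__ float normalize_old(half h, half s) { return __half2float(h) / __half2float(s); }
    __device__ float normalize_new(half h, half s) { return __half2float(h / s); }
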
@@ -9618,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
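
The nrows_y change above and the stricter mask assert in ggml_cuda_flash_attn_ext further down follow from the same constraint: the kernel loads the mask in 16x16 WMMA tiles (16 consecutive mask rows per load_matrix_sync), so the mask's second dimension is padded to a multiple of 16 and can exceed the number of query rows. soft_max therefore takes its row count from src0, and the flash-attention path asserts mask->ne[1] >= GGML_PAD(Q->ne[1], 16). A hedged restatement of the padding arithmetic with worked values (GGML_PAD itself is defined in ggml.h):

    #include <stdint.h>

    // Arithmetic equivalent of GGML_PAD(x, n) for the alignments used here:
    static inline int64_t pad_to(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }
    // pad_to( 1, 16) == 16
    // pad_to(16, 16) == 16
    // pad_to(17, 16) == 32   -> a batch of 17 queries needs a mask with ne[1] >= 32
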
@@ -10897,8 +10923,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
Expand All @@ -10912,19 +10938,25 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));

const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
// const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
const int nwarps = 1;
#define NQPB 16
#define NCPW 128

const int nqpb = NQPB; // queries per block
const int ncpw = NCPW; // cache values per warp (does not work for other values)

const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);

int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
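
The launch setup now routes the same NQPB/NCPW constants into both the template arguments and the shared-memory size. Per block, shared memory holds the Q tile (nqpb rows of the head dimension) plus, for each warp, an (ncpw + nqpb)-wide scratch row per query; sizeof(float)/2 is simply sizeof(half). A worked example under assumed values (head size 128, and the heuristic above picking nwarps = 2 for a long query batch):

    #include <cstddef>
    #include <cuda_fp16.h>

    // Assumed values: D = Q->ne[0] = 128, nqpb = 16, ncpw = 128, nwarps = 2.
    constexpr size_t D = 128, nqpb = 16, ncpw = 128, nwarps = 2;
    constexpr size_t shmem = nqpb*(D + nwarps*(ncpw + nqpb))*sizeof(half);
    static_assert(shmem == 13312, "16*(128 + 2*144)*2 bytes = 13 KiB");

When the heuristic picks its maximum of 8 warps, the same formula gives 16*(128 + 8*144)*2 = 40960 bytes, which still fits in the default 48 KiB of shared memory per block, so the nwarps_max = 8 cap also keeps the kernel launchable without opting into larger dynamic shared memory.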

switch (Q->ne[0])
{
case 16:
flash_attn_ext_f16<16, 16, 32>
flash_attn_ext_f16<16, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -10941,7 +10973,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 64:
flash_attn_ext_f16<64, 16, 32>
flash_attn_ext_f16<64, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -10958,7 +10990,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 80:
flash_attn_ext_f16<80, 16, 32>
flash_attn_ext_f16<80, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -10975,7 +11007,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 128:
flash_attn_ext_f16<128, 16, 32>
flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key