Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda : fix flash_attn kernel to produce same results as CPU #3

Merged
merged 4 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 113 additions & 81 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix

half16x16_acc zr;
half16x16_acc lo[Q16][D16];

// load heads from Q to shared memory
Expand All @@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
}
}

nvcuda::wmma::fill_fragment(zr, 0.0);

// zero out lo
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
Expand All @@ -6487,12 +6491,12 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();

{
float S[Q];
float M[Q];
half S[Q];
half M[Q];

for(int i = 0; i < Q; i++) {
S[i] = 0.0f;
M[i] = -INFINITY;
S[i] = __float2half(0.0f);
M[i] = __float2half(-INFINITY);
}

// assume K and V are same shape
Expand Down Expand Up @@ -6526,11 +6530,16 @@ static __global__ void flash_attn_ext_f16(
}
}

const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

// prepare diagonal scale matrix
half16x16_b mscale;
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
nvcuda::wmma::load_matrix_sync(mscale, ss, T);

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
Expand All @@ -6555,111 +6564,129 @@ static __global__ void flash_attn_ext_f16(

// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
// TODO: process mask
mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
}
half16x16_a mqka;
half16x16_acc mm;

// convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);

nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
}

// used to detect blocks full of -INF
float smax = -INFINITY;
half smax = __float2half(-INFINITY);

// online softmax
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = lane_id;

const float m = M[j];
const float s = __half2float(ss[j*T + p]);
const half m = M[j];
const half s = ss[j*T + p];

smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
smax = warp_reduce_max(__hmax(smax, s));
M[j] = warp_reduce_max(__hmax(M[j], s));

const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

S[j] = S[j]*ms + warp_reduce_sum(vs);

// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = __float2half(ms);
ss[j*T + C + j] = ms;
}

// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
ss[j*T + p] = vs;
}
} else {
for (int64_t j = 0; j < Q; ++j) {
const float m = M[j];
const half m = M[j];

for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
const half s = ss[j*T + p];

smax = warp_reduce_max(max(smax, s));
M[j] = warp_reduce_max(max(M[j], s));
smax = __hmax(smax, s);
M[j] = __hmax(M[j], s);
}

const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
smax = warp_reduce_max(smax);
M[j] = warp_reduce_max(M[j]);

S[j] = S[j]*ms;
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = __float2half(ms);
ss[j*T + C + j] = ms;
}

// local sum
half ls = 0.0f;

for (int64_t p = lane_id; p < C; p += NW) {
const float s = __half2float(ss[j*T + p]);
const half s = ss[j*T + p];

const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

S[j] = S[j] + warp_reduce_sum(vs);
ls += vs;

// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = __float2half(vs);
ss[j*T + p] = vs;
}

S[j] = S[j]*ms + warp_reduce_sum(ls);
}
}

// skip -INF blocks
if (smax == -INFINITY) {
if (__hisinf(smax)) {
continue;
}

// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
// half16x16_a mm;
// half16x16_b zro;
half16x16_a mm;
half16x16_b lob;

// nvcuda::wmma::fill_fragment(zro, 0.0);
// nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

for (int64_t i = 0; i < D16; ++i) {
//nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
for (uint32_t k = 0; k < 16*16; k++) {
half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
lo[j][i].x[k] = tmp * lo[j][i].x[k];
}
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);

nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
}

// restore zeros
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}

// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

half16x16_b mk[D16];
for (int64_t i = 0; i < D16; ++i) {
half16x16_b mk;
nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
}

for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mv;
nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
half16x16_a mv[Q16];
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
}

nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
}
}
}
Expand All @@ -6669,16 +6696,16 @@ static __global__ void flash_attn_ext_f16(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
if (lane_id == 0) {
ss[j*T + 0] = __float2half(S[j]);
ss[j*T + 1] = __float2half(M[j]);
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
}
}

// reduce the warps sequentially
for (int64_t sg = 1; sg < num_warps; ++sg) {
float S = 0.0f;
float M = -INFINITY;
half S = __float2half(0.0f);
half M = __float2half(-INFINITY);

__syncthreads();

Expand All @@ -6696,25 +6723,25 @@ static __global__ void flash_attn_ext_f16(
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int64_t j = 0; j < Q; ++j) {
const float S0 = __half2float(ss[j*T + 0]);
const float S1 = __half2float(ss[j*T + sg*SH + 0]);
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];

const float M0 = __half2float(ss[j*T + 1]);
const float M1 = __half2float(ss[j*T + sg*SH + 1]);
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + sg*SH + 1];

M = max(M0, M1);
M = __hmax(M0, M1);

const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

S = S0*ms0 + S1*ms1;

if (lane_id == 0) {
ss[j*T + 0] = __float2half(S);
ss[j*T + 1] = __float2half(M);
ss[j*T + 0] = S;
ss[j*T + 1] = M;

ss[j*T + C + j ] = __float2half(ms0);
ss[j*T + C + j + sg*SH] = __float2half(ms1);
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + sg*SH] = ms1;
}
}

Expand All @@ -6732,10 +6759,11 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
// store temporally 'lo' data
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
// load 'lo' data into t
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);

// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);

nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
Expand All @@ -6751,15 +6779,13 @@ static __global__ void flash_attn_ext_f16(
}
}

// float2 * dst2 = (float2 *) dst;

// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const float S = __half2float(ss[j*T + 0]);
const half S = ss[j*T + 0];

for (int64_t i = lane_id; i < D; i += NW) {
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}
}
Expand Down Expand Up @@ -9618,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
Expand Down Expand Up @@ -10897,8 +10923,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
Expand All @@ -10912,19 +10938,25 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));

const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
// const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
const int nwarps = 1;
#define NQPB 16
#define NCPW 128

const int nqpb = NQPB; // queries per block
const int ncpw = NCPW; // cache values per warp (does not work for other values)

const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);

int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);

switch (Q->ne[0])
{
case 16:
flash_attn_ext_f16<16, 16, 32>
flash_attn_ext_f16<16, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
Expand All @@ -10941,7 +10973,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 64:
flash_attn_ext_f16<64, 16, 32>
flash_attn_ext_f16<64, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
Expand All @@ -10958,7 +10990,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 80:
flash_attn_ext_f16<80, 16, 32>
flash_attn_ext_f16<80, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
Expand All @@ -10975,7 +11007,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 128:
flash_attn_ext_f16<128, 16, 32>
flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
Expand Down
Loading