optional flash-decoding
FSSRepo committed Mar 7, 2024
1 parent f490812 commit eecf7ee
Showing 2 changed files with 15 additions and 13 deletions.
ggml-cuda.cu (11 changes: 6 additions & 5 deletions)
@@ -1,7 +1,7 @@
 #include "ggml-cuda.h"
 #include "ggml.h"
 #include "ggml-backend-impl.h"
-
+#define GGML_FLASH_DECODING
 #include <algorithm>
 #include <assert.h>
 #include <atomic>
@@ -7339,7 +7339,6 @@ static __global__ void flash_attn_ext_f16(
         // create a QxQ diagonal matrix for rescaling the output
         if (lane_id == j && !__hisnan(ms)) {
             ss[j*T + C + j] = ms;
-
             S = S*ms + ls.x + ls.y;
         }
     }
@@ -11980,9 +11979,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
 #define NQPB 16
 #define NCPW 128
 
-    bool flash_decoding = true;
-
-    if(!flash_decoding || ne00 != 128 || ne01 > 1) {
+#ifdef GGML_FLASH_DECODING
+    if(ne00 != 128 || ne01 > 1) {
+#endif
     const int nqpb = NQPB; // queries per block
     const int ncpw = NCPW; // cache values per warp (does not work for other values)
 
@@ -12101,6 +12100,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
         default:
             break;
     }
+#ifdef GGML_FLASH_DECODING
     } else {
 #define WMMA_M 16
 #define WMMA_N 16
@@ -12130,6 +12130,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor
             (const half*)src1_extra->data_device[g_main_device],
             (float *)dst_extra->data_device[g_main_device], ne11, ne11 / kv_per_block, reduce_block);
     }
+#endif
 
     CUDA_CHECK(cudaGetLastError());
 }
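
Note on the dispatch above: with GGML_FLASH_DECODING defined, ggml_cuda_flash_attn_ext routes single-query calls with a head size of 128 to the flash-decoding kernel and everything else to the existing flash-attention kernel; with the macro undefined, only the flash-attention path is compiled. A minimal standalone sketch of that dispatch shape follows; launch_flash_attn() and launch_flash_decoding() are hypothetical stand-ins for the two kernel-launch blocks in this file, not real functions from the diff.

    // sketch.cpp: illustrative only, not the actual implementation
    #include <cstdio>

    #define GGML_FLASH_DECODING

    static void launch_flash_attn()     { std::printf("flash-attention path\n"); }
    static void launch_flash_decoding() { std::printf("flash-decoding path\n");  }

    static void dispatch(int ne00 /* head size */, int ne01 /* queries in the batch */) {
    #ifdef GGML_FLASH_DECODING
        if (ne00 != 128 || ne01 > 1) {
            launch_flash_attn();      // prompt processing / batched queries
        } else {
            launch_flash_decoding();  // single-query decode with 128-dim heads
        }
    #else
        launch_flash_attn();          // macro undefined: decode path is compiled out
    #endif
    }

    int main() {
        dispatch(128, 512); // prompt -> flash-attention kernel
        dispatch(128,   1); // decode -> flash-decoding kernel
        return 0;
    }

To disable the flash-decoding path, the #define GGML_FLASH_DECODING lines (one at the top of each changed file) would presumably be removed together, since llama.cpp's graph layout and the CUDA dispatch both key off the same macro.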
llama.cpp (17 changes: 9 additions & 8 deletions)
@@ -1,6 +1,6 @@
 #define LLAMA_API_INTERNAL
 #include "llama.h"
-
+#define GGML_FLASH_DECODING
 #include "unicode.h"
 
 #include "ggml.h"
@@ -5025,8 +5025,6 @@ static struct ggml_tensor * llm_build_kqv(
 
     GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
 
-    bool flash_decoding = true;
-
     // split cached v into n_head heads (not transposed)
     struct ggml_tensor * v =
             ggml_view_3d(ctx, kv.v_l[il],
@@ -5035,11 +5033,14 @@
                 ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
                 0);
     cb(v, "v", il);
-
-    cur = ggml_flash_attn_ext(ctx, q, ggml_cont(ctx, k), ggml_cont(ctx,
-            flash_decoding && n_tokens == 1 ?
-            ggml_permute(ctx, v, 1, 0, 2, 3) : v), ggml_cont(ctx, kq_mask), kq_scale);
-
+#ifdef GGML_FLASH_DECODING
+    cur = ggml_flash_attn_ext(ctx, q,
+            n_tokens == 1 ? ggml_cont(ctx, k) : k,
+            n_tokens == 1 ? ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3)) : v,
+            ggml_cont(ctx, kq_mask), kq_scale);
+#else
+    cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+#endif
     ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
     //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
     //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
