Skip to content

Commit

Permalink
ggml : ggml_flash_attn_ext() support ALiBi (Metal)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 10, 2024
1 parent 166e60b commit 97c27f5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 47 deletions.
76 changes: 44 additions & 32 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];

const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
Expand Down Expand Up @@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");

const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0;
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);

Expand All @@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute(
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);

float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float max_bias;

memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));

const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

id<MTLComputePipelineState> pipeline = nil;

Expand Down Expand Up @@ -2562,34 +2571,37 @@ static enum ggml_status ggml_metal_graph_compute(
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&scale length:sizeof( float) atIndex:26];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];

if (!use_vec_kernel) {
// half8x8 kernel
Expand Down
50 changes: 43 additions & 7 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
Expand Down Expand Up @@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
Expand Down Expand Up @@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
// prepare diagonal scale matrix
simdgroup_float8x8 mscale(scale);

// prepare diagonal slope matrix
simdgroup_float8x8 mslope(1.0f);

// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

mslope = simdgroup_float8x8(pow(base, exph));
}

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
Expand All @@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}

// mqk = mqk*scale + mask
// mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);

simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
Expand Down Expand Up @@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
Expand All @@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(

const short T = D + 2*nsg*SH; // shared memory size per query in (half)

float slope = 1.0f;

// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;

const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slope = pow(base, exp);
}

//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
Expand Down Expand Up @@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);

// mqk = mqk*scale + mask
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
float4 mm = (float4) mp4[ic/4 + cc];
mqk = mqk*scale + mm;
mqk = mqk*scale + mm*slope;

ss4[cc] = mqk;
}
Expand Down Expand Up @@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

dst_data[i00] = src[0];
// TODO: is there a better way to handle -INFINITY?
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
}
}

Expand Down
6 changes: 6 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10985,6 +10985,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
}

for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
} else {
// when using kv cache, the mask needs to match the kv cache size
Expand Down
20 changes: 12 additions & 8 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1486,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
const int64_t kv; // kv size
const int64_t nb; // batch size

const float max_bias; // ALiBi

std::string vars() override {
return VARS_TO_STR4(hs, nh, kv, nb);
return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
}

double max_nmse_err() override {
return 5e-4;
}

test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
: hs(hs), nh(nh), kv(kv), nb(nb) {}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
: hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
return out;
}
};
Expand Down Expand Up @@ -2176,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
#else
for (int hs : { 64, 80, 128, 256, }) {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
for (float max_bias : {0.0f, 8.0f}) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
}
}
}
}
Expand Down

0 comments on commit 97c27f5

Please sign in to comment.