ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
ggerganov committed Jan 31, 2024
1 parent 2ddc9bb · commit 8ad92dc
Showing 7 changed files with 79 additions and 62 deletions.
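For callers, the practical effect of this change is that the KQ mask tensor is now F16 instead of F32, and for the Flash-Attention path its second dimension must be padded (see the GGML_KQ_MASK_PAD constant added in ggml.h below). A minimal caller-side sketch, assuming ggml's existing ggml_new_tensor_2d and GGML_PAD helpers; the function name and sizes here are illustrative, not taken from this diff:

    // build the KQ mask as F16; rows padded to GGML_KQ_MASK_PAD so the same
    // tensor can feed both ggml_soft_max and ggml_flash_attn_ext
    static struct ggml_tensor * build_kq_mask(struct ggml_context * ctx, int64_t n_kv, int64_t n_batch) {
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));
        // the data is expected to hold 0.0 for visible positions and -INFINITY for
        // masked ones, stored as ggml_fp16_t (filling is omitted here)
        return mask;
    }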
20 changes: 10 additions & 10 deletions ggml-cuda.cu
@@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
}

template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
- static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+ static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
if (need_check && col_data + 0 >= ncols_data) {
val.x = -INFINITY;
} else {
- val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+ val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
}
if (need_check && col_data + WARP_SIZE >= ncols_data) {
val.y = -INFINITY;
} else {
- val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+ val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
}
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
vals[col_smem] = val;
@@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
}

template <bool vals_smem, int ncols_template, int block_size_template>
- static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+ static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

const int tid = threadIdx.x;
@@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;

- const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+ const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
@@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

- static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+ static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
}
}

- static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+ static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max(
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX

if (use_f16_soft_max) {
- soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+ soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
} else {
- soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+ soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
}

(void) dst;
6 changes: 6 additions & 0 deletions ggml-metal.m
@@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
+
int nth = 32; // SIMD width

id<MTLComputePipelineState> pipeline = nil;
@@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute(

id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;

+ GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+ "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
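The asserts added above encode the Metal kernel's padding contract for the Flash-Attention mask. A hedged restatement of the same check in plain C (the helper name and the PAD_TO macro are hypothetical; ggml itself uses its GGML_PAD macro):

    #include <stdbool.h>
    #include <stdint.h>

    // round x up to a multiple of n (same idea as ggml's GGML_PAD)
    #define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))

    // mirrors the new assertion in ggml-metal.m: the mask's second dimension must
    // cover the number of query rows rounded up to a multiple of 8
    static bool kq_mask_fits_metal(int64_t mask_ne1, int64_t n_queries) {
        return mask_ne1 >= PAD_TO(n_queries, 8);
    }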
40 changes: 19 additions & 21 deletions ggml-metal.metal
@@ -349,9 +349,9 @@ kernel void kernel_sum_rows(
}

kernel void kernel_soft_max(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -366,9 +366,9 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float lmax = -INFINITY;
@@ -435,14 +435,14 @@ kernel void kernel_soft_max(
}

kernel void kernel_soft_max_4(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
- threadgroup float * buf [[threadgroup(0)]],
+ threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -452,15 +452,15 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

// parallel max
float4 lmax4 = -INFINITY;

for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f));
}

const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +486,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
@@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16(
}
}

- const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
// pointer to the mask
- device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
+ device const half * mp = (device const half *) (mask + iq1*nb31);

// prepare diagonal scale matrix
- simdgroup_float8x8 mscale(scale);
+ simdgroup_half8x8 mscale(scale);

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
@@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(

// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q8; ++j) {
- simdgroup_float8x8 mm;
- simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
+ simdgroup_half8x8 mm;
+ simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);

simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
13 changes: 9 additions & 4 deletions ggml.c
@@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
+ GGML_ASSERT(mask->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
@@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
}

@@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32(
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

// broadcast the mask across rows
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+ ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;

ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (mp) {
- ggml_vec_acc_f32(nc, wp, mp);
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += GGML_FP16_TO_FP32(mp[i]);
+ }
}

#ifndef NDEBUG
@@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(

memset(V16, 0, D*sizeof(ggml_fp16_t));

- const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;

// k indices
const int ik3 = iq3 / rk3;
@@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
- const float mv = mp ? mp[ic] : 0.0f;
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}
12 changes: 7 additions & 5 deletions ggml.h
@@ -1646,11 +1646,13 @@ extern "C" {
struct ggml_tensor * v,
bool masked);

- // q: [n_embd, n_batch, n_head, 1]
- // k: [n_embd, n_kv, n_head_kv, 1]
- // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
- // mask: [n_kv, n_batch, 1, 1]
- // res: [n_embd, n_head, n_batch, 1] !! permuted !!
+ #define GGML_KQ_MASK_PAD 32
+
+ // q: [n_embd, n_batch, n_head, 1]
+ // k: [n_embd, n_kv, n_head_kv, 1]
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
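A hedged usage sketch of the contract described by the updated comment above; the tail of the ggml_flash_attn_ext parameter list is truncated in this diff, so the trailing (k, v, mask, scale) arguments are assumed from the comment rather than copied from the header:

    // sketch under the assumptions stated above: build a padded F16 mask and
    // call Flash-Attention with it (filling the mask with 0 / -INFINITY is omitted)
    static struct ggml_tensor * attn_with_padded_mask(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,   // [n_embd, n_batch,  n_head,    1]
            struct ggml_tensor  * k,   // [n_embd, n_kv,     n_head_kv, 1]
            struct ggml_tensor  * v,   // [n_embd, n_kv,     n_head_kv, 1]
            float                 scale) {
        const int64_t n_kv    = k->ne[1];
        const int64_t n_batch = q->ne[1];

        // mask: [n_kv, n_batch_pad, 1, 1] with n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD)
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

        return ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
    }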