From fb085fabadb93949b3e148b8dc1e465643048b85 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 1 Feb 2024 15:00:47 +0200
Subject: [PATCH] cuda : fix to F16 scalars + tune warps for RTX 2060

---
 ggml-cuda.cu               | 94 ++++++++++++++++++++------------------
 tests/test-backend-ops.cpp | 14 +++++-
 2 files changed, 61 insertions(+), 47 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index bdd50e2b6be4c3..330fc6290effa9 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];

         for(int i = 0; i < Q; i++) {
             S[i] = 0.0f;
@@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16(
             }

             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = -INFINITY;

             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;

-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];

-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                    const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                     S[j] = S[j]*ms + warp_reduce_sum(vs);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = warp_reduce_max(__hmax(smax, s));
+                        M[j] = warp_reduce_max(__hmax(M[j], s));
                     }

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);

                     S[j] = S[j]*ms;

                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                         S[j] = S[j] + warp_reduce_sum(vs);

                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
                 }
             }

+
             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }

@@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16(

         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }

     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = 0.0f;
+        half M = -INFINITY;

         __syncthreads();
@@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16(

         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];

-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];

-                M = max(M0, M1);
+                M = __hmax(M0, M1);

-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
+                const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);

                 S = S0*ms0 + S1*ms1;

                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;

-                    ss[j*T + C + j ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }
@@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16(
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];

             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
     }
@@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
+#define NQPB 16
+#define NCPW 32
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)

     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
     // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
@@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

     switch (Q->ne[0]) {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e632142a74a13d..ff207e21b8ec32 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 1
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 1000;
+        int n_nodes = gf->n_nodes;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 64, 80, 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {