Skip to content

Commit

Permalink
Use more Kahan summation
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Apr 30, 2024
1 parent 2af3b88 commit 9540b43
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
36 changes: 34 additions & 2 deletions llama.cpp/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ do { \
#define GGML_F32x4_STORE vst1q_f32
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32x4_ADD vaddq_f32
#define GGML_F32x4_SUB vsubq_f32
#define GGML_F32x4_MUL vmulq_f32
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
#define GGML_F32x4_REDUCE(res, x) \
Expand All @@ -913,6 +914,7 @@ do { \
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_SUB GGML_F32x4_SUB
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

Expand All @@ -929,6 +931,7 @@ do { \
#define GGML_F16x8_STORE vst1q_f16
#define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
#define GGML_F16x8_ADD vaddq_f16
#define GGML_F16x8_SUB vsubq_f16
#define GGML_F16x8_MUL vmulq_f16
#define GGML_F16x8_REDUCE(res, x) \
do { \
Expand Down Expand Up @@ -956,6 +959,7 @@ do { \
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F16x8_FMA
#define GGML_F16_VEC_ADD GGML_F16x8_ADD
#define GGML_F16_VEC_SUB GGML_F16x8_SUB
#define GGML_F16_VEC_MUL GGML_F16x8_MUL
#define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
#else
Expand All @@ -972,6 +976,7 @@ do { \
#define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
#define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32Cx4_ADD vaddq_f32
#define GGML_F32Cx4_SUB vsubq_f32
#define GGML_F32Cx4_MUL vmulq_f32
#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE

Expand All @@ -982,6 +987,7 @@ do { \
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
#define GGML_F16_VEC_SUB GGML_F32Cx4_SUB
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#endif
Expand All @@ -1003,6 +1009,7 @@ do { \
// _mm512_fmadd_ps is defined in AVX512F so no guard is required
#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#define GGML_F32x16_ADD _mm512_add_ps
#define GGML_F32x16_SUB _mm512_sub_ps
#define GGML_F32x16_MUL _mm512_mul_ps
#define GGML_F32x16_REDUCE(res, x) \
do { \
Expand Down Expand Up @@ -1030,6 +1037,7 @@ do { \
#define GGML_F32_VEC_STORE GGML_F32x16_STORE
#define GGML_F32_VEC_FMA GGML_F32x16_FMA
#define GGML_F32_VEC_ADD GGML_F32x16_ADD
#define GGML_F32_VEC_SUB GGML_F32x16_SUB
#define GGML_F32_VEC_MUL GGML_F32x16_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE

Expand All @@ -1053,6 +1061,7 @@ do { \

#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#define GGML_F32Cx16_ADD _mm512_add_ps
#define GGML_F32Cx16_SUB _mm512_sub_ps
#define GGML_F32Cx16_MUL _mm512_mul_ps
#define GGML_F32Cx16_REDUCE(res, x) \
do { \
Expand All @@ -1078,6 +1087,7 @@ do { \
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
#define GGML_F16_VEC_SUB GGML_F32Cx16_SUB
#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE

Expand All @@ -1101,6 +1111,7 @@ do { \
#define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
#endif
#define GGML_F32x8_ADD _mm256_add_ps
#define GGML_F32x8_SUB _mm256_sub_ps
#define GGML_F32x8_MUL _mm256_mul_ps
#define GGML_F32x8_REDUCE(res, x) GGML_F32x8_REDUCE_AVX(res, x)
// TODO: is this optimal ?
Expand All @@ -1112,6 +1123,7 @@ do { \
#define GGML_F32_VEC_STORE GGML_F32x8_STORE
#define GGML_F32_VEC_FMA GGML_F32x8_FMA
#define GGML_F32_VEC_ADD GGML_F32x8_ADD
#define GGML_F32_VEC_SUB GGML_F32x8_SUB
#define GGML_F32_VEC_MUL GGML_F32x8_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE

Expand Down Expand Up @@ -1154,6 +1166,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {

#define GGML_F32Cx8_FMA GGML_F32x8_FMA
#define GGML_F32Cx8_ADD _mm256_add_ps
#define GGML_F32Cx8_SUB _mm256_sub_ps
#define GGML_F32Cx8_MUL _mm256_mul_ps
#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE

Expand All @@ -1164,6 +1177,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
#define GGML_F16_VEC_SUB GGML_F32Cx8_SUB
#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE

Expand All @@ -1183,6 +1197,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
#define GGML_F32x4_ADD vec_add
#define GGML_F32x4_SUB vec_sub
#define GGML_F32x4_MUL vec_mul
#define GGML_F32x4_REDUCE(res, x) \
{ \
Expand Down Expand Up @@ -1211,6 +1226,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_SUB GGML_F32x4_SUB
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

Expand Down Expand Up @@ -1249,6 +1265,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F32x4_STORE wasm_v128_store
#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
#define GGML_F32x4_ADD wasm_f32x4_add
#define GGML_F32x4_SUB wasm_f32x4_sub
#define GGML_F32x4_MUL wasm_f32x4_mul
#define GGML_F32x4_REDUCE(res, x) \
{ \
Expand Down Expand Up @@ -1277,6 +1294,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_SUB GGML_F32x4_SUB
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

Expand Down Expand Up @@ -1314,6 +1332,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
#define GGML_F16x4_FMA GGML_F32x4_FMA
#define GGML_F16x4_ADD wasm_f32x4_add
#define GGML_F16x4_SUB wasm_f32x4_sub
#define GGML_F16x4_MUL wasm_f32x4_mul
#define GGML_F16x4_REDUCE(res, x) \
{ \
Expand Down Expand Up @@ -1342,6 +1361,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F16x4_FMA
#define GGML_F16_VEC_ADD GGML_F16x4_ADD
#define GGML_F16_VEC_SUB GGML_F16x4_SUB
#define GGML_F16_VEC_MUL GGML_F16x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE

Expand All @@ -1366,6 +1386,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
#define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
#endif
#define GGML_F32x4_ADD _mm_add_ps
#define GGML_F32x4_SUB _mm_sub_ps
#define GGML_F32x4_MUL _mm_mul_ps
#define GGML_F32x4_REDUCE(res, x) \
{ \
Expand Down Expand Up @@ -1393,6 +1414,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_SUB GGML_F32x4_SUB
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

Expand Down Expand Up @@ -1430,6 +1452,7 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
#define GGML_F32Cx4_FMA GGML_F32x4_FMA
#define GGML_F32Cx4_ADD _mm_add_ps
#define GGML_F32Cx4_SUB _mm_sub_ps
#define GGML_F32Cx4_MUL _mm_mul_ps
#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE

Expand All @@ -1440,6 +1463,7 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
#define GGML_F16_VEC_SUB GGML_F32Cx4_SUB
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE

Expand Down Expand Up @@ -1579,6 +1603,7 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
GGML_F16_VEC err[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
Expand All @@ -1588,7 +1613,10 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
GGML_F16_VEC ky = GGML_F16_VEC_SUB(GGML_F16_VEC_MUL(ax[j], ay[j]), err[j]);
GGML_F16_VEC kt = GGML_F16_VEC_ADD(sum[j], ky);
err[j] = GGML_F16_VEC_SUB(GGML_F16_VEC_SUB(kt, sum[j]), ky);
sum[j] = kt;
}
}

Expand Down Expand Up @@ -1623,6 +1651,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
GGML_F16_VEC err[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };

GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
Expand All @@ -1634,7 +1663,10 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);

sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
GGML_F16_VEC ky = GGML_F16_VEC_SUB(GGML_F16_VEC_MUL(ax[j], ay[j]), err[k][j]);
GGML_F16_VEC kt = GGML_F16_VEC_ADD(sum[k][j], ky);
err[k][j] = GGML_F16_VEC_SUB(GGML_F16_VEC_SUB(kt, sum[k][j]), ky);
sum[k][j] = kt;
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions llamafile/tinyblas_cpu.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,9 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
return false;
}
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 2)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return false;
if (k % 8)
return false;
if (Btype != GGML_TYPE_F16)
Expand All @@ -1097,6 +1100,9 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return false;
if (k % 4)
return false;
if (Btype != GGML_TYPE_F32)
Expand Down

0 comments on commit 9540b43

Please sign in to comment.