diff --git a/common/common.cpp b/common/common.cpp index 06f252ea6914b9..87609d359a312a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -73,7 +73,7 @@ using json = nlohmann::ordered_json; int32_t get_num_physical_cores() { -#ifdef __linux__ +#if defined(__linux__) || defined(__COSMOPOLITAN__) // enumerate the set of thread siblings, num entries is num cores std::unordered_set<std::string> siblings; for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { @@ -108,7 +108,7 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) #include <cpuid.h> static void cpuid(unsigned leaf, unsigned subleaf, @@ -162,7 +162,7 @@ static int count_math_cpus(int cpu_count) { * Returns number of CPUs on system that are useful for math. */ int get_math_cpu_count() { -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); if (cpu_count < 1) { return get_num_physical_cores(); diff --git a/ggml.c b/ggml.c index 086db96af7fcd1..88fd4823db601e 100644 --- a/ggml.c +++ b/ggml.c @@ -11003,11 +11003,14 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + if (llamafile_mixmul(params, src0, src1, ids, dst)) + return; const int ith = params->ith; const int nth = params->nth; + GGML_TENSOR_BINARY_OP_LOCALS + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -18504,6 +18507,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const 
struct ggml_tensor * src2 = node->src[2]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); @@ -18512,6 +18516,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + size_t cur2 = llamafile_mixmul_needs(src0, src1, src2); + cur = cur > cur2 ? cur : cur2; } break; case GGML_OP_OUT_PROD: { diff --git a/sgemm.cpp b/sgemm.cpp index 531e12af361ccf..e572d3ff25ec53 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -54,6 +54,10 @@ #include "ggml-impl.h" #include "ggml-quants.h" +#define ROW_ALIGN 64 +#define MATRIX_ALIGN 4096 +#define MAX_ALIGN 4096 + #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else @@ -66,14 +70,29 @@ #define VECTOR_REGISTERS 16 #endif -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - namespace { inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MATRIX MEMORY INDEXING + +#define NCA 1 +#define NCB 2 +#define NCC 4 + +#define INDEX(A, lda, j, i) (CONFIG & NC##A ? 
((T##A *const *)A)[j] + i : A + lda * (j) + i) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GGML TYPE TRAITS + +template struct ggml_type_trait; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F32; }; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F16; }; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_Q8_0; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS @@ -118,6 +137,21 @@ inline U madd(T a, T b, U c) { return add(mul(a, b), c); } +/** + * Computes a * b + c with error correction. + * + * @see W. Kahan, "Further remarks on reducing truncation errors," + * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, + * doi: 10.1145/363707.363723. + */ +template +inline U madder(T a, T b, U c, U *e) { + U y = sub(mul(a, b), *e); + U t = add(c, y); + *e = sub(sub(t, c), y); + return t; +} + #if defined(__FMA__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> @@ -136,10 +170,10 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) { #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { - return vfmaq_f32(c, b, a); + return vfmaq_f32(c, a, b); } -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -template <> +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(__clang__) +template <> // this specialization chops gcc 12.3 performance in half inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } @@ -157,6 +191,7 @@ inline float hsum(float32x4_t x) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { + // todo: works great for clang but produces sketchy code with gcc 12.3 return 
vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } @@ -240,14 +275,10 @@ template <> inline __m512 load(const ggml_fp16_t *p) { //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION -template +template class tinyBLAS { public: - tinyBLAS(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, - int ith, int nth) + tinyBLAS(int k, const TA *A, int lda, const TB *B, int ldb, TC *C, int ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -262,39 +293,35 @@ class tinyBLAS { switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: + case 0x54: mc = 5; - nc = 5; - gemm<5, 5>(m0, m, n0, n); + nc = 4; + gemm<5, 4, false>(m0, m, n0, n); break; case 0x45: mc = 4; nc = 5; - gemm<4, 5>(m0, m, n0, n); - break; - case 0x54: - mc = 5; - nc = 4; - gemm<5, 4>(m0, m, n0, n); + gemm<4, 5, false>(m0, m, n0, n); break; case 0x44: mc = 4; nc = 4; - gemm<4, 4>(m0, m, n0, n); + gemm<4, 4, false>(m0, m, n0, n); break; case 0x53: mc = 5; nc = 3; - gemm<5, 3>(m0, m, n0, n); + gemm<5, 3, false>(m0, m, n0, n); break; case 0x35: mc = 3; nc = 5; - gemm<3, 5>(m0, m, n0, n); + gemm<3, 5, false>(m0, m, n0, n); break; case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; #else case 0x55: @@ -305,99 +332,99 @@ class tinyBLAS { case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; case 0x35: #endif case 0x34: mc = 3; nc = 4; - gemm<3, 4>(m0, m, n0, n); + gemm<3, 4, false>(m0, m, n0, n); break; case 0x52: mc = 5; nc = 2; - gemm<5, 2>(m0, m, n0, n); + gemm<5, 2, false>(m0, m, n0, n); break; case 0x33: mc = 3; nc = 3; - gemm<3, 3>(m0, m, n0, n); + gemm<3, 3, false>(m0, m, n0, n); break; case 0x25: mc = 2; nc = 5; - gemm<2, 5>(m0, m, n0, n); + gemm<2, 5, false>(m0, m, n0, n); break; case 
0x42: mc = 4; nc = 2; - gemm<4, 2>(m0, m, n0, n); + gemm<4, 2, false>(m0, m, n0, n); break; case 0x24: mc = 2; nc = 4; - gemm<2, 4>(m0, m, n0, n); + gemm<2, 4, false>(m0, m, n0, n); break; case 0x32: mc = 3; nc = 2; - gemm<3, 2>(m0, m, n0, n); + gemm<3, 2, true>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; - gemm<2, 3>(m0, m, n0, n); + gemm<2, 3, true>(m0, m, n0, n); break; case 0x51: mc = 5; nc = 1; - gemm<5, 1>(m0, m, n0, n); + gemm<5, 1, true>(m0, m, n0, n); break; case 0x41: mc = 4; nc = 1; - gemm<4, 1>(m0, m, n0, n); + gemm<4, 1, true>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; - gemm<2, 2>(m0, m, n0, n); + gemm<2, 2, true>(m0, m, n0, n); break; case 0x15: mc = 1; nc = 5; - gemm<1, 5>(m0, m, n0, n); + gemm<1, 5, true>(m0, m, n0, n); break; case 0x14: mc = 1; nc = 4; - gemm<1, 4>(m0, m, n0, n); + gemm<1, 4, true>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; - gemm<3, 1>(m0, m, n0, n); + gemm<3, 1, true>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; - gemm<1, 3>(m0, m, n0, n); + gemm<1, 3, true>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; - gemm<2, 1>(m0, m, n0, n); + gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; - gemm<1, 2>(m0, m, n0, n); + gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; - gemm<1, 1>(m0, m, n0, n); + gemm<1, 1, true>(m0, m, n0, n); break; default: return; @@ -408,10 +435,10 @@ class tinyBLAS { mnpack(m0, m, np, n); } - template + template NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; + int ytiles = RM > 1 ? (m - m0) / RM : 1; + int xtiles = RN > 1 ? 
(n - n0) / RN : 1; int tiles = xtiles * ytiles; int duty = (tiles + nth - 1) / nth; int start = duty * ith; @@ -422,15 +449,21 @@ class tinyBLAS { int ii = m0 + job / xtiles * RM; int jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; + D Ce[RN][RM] = {}; for (int l = 0; l < k; l += KN) for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - Cv[j][i] = madd(load(A + lda * (ii + i) + l), - load(B + ldb * (jj + j) + l), - Cv[j][i]); + if (KAHAN) + Cv[j][i] = madder(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i], &Ce[j][i]); + else + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i]); for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -449,14 +482,11 @@ class tinyBLAS { // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) -template +template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int k, - const TA *A, int lda, - const block_q8_0 *B, int ldb, - float *C, int ldc, - int ith, int nth) + tinyBLAS_Q0_ARM(int k, const TA *A, int lda, const TB *B, int ldb, TC *C, int ldc, int ith, + int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -525,8 +555,8 @@ class tinyBLAS_Q0_ARM { template NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; + int ytiles = RM > 1 ? (m - m0) / RM : 1; + int xtiles = RN > 1 ? 
(n - n0) / RN : 1; int tiles = xtiles * ytiles; int duty = (tiles + nth - 1) / nth; int start = duty * ith; @@ -540,18 +570,18 @@ class tinyBLAS_Q0_ARM { for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - Cv[j][i] = vmlaq_n_f32(Cv[j][i], - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); + Cv[j][i] = vmlaq_n_f32( + Cv[j][i], + vcvtq_f32_s32(vdotq_s32(vdotq_s32(vdupq_n_s32(0), + load_lo(INDEX(A, lda, ii + i, l)), + load_lo(INDEX(B, ldb, jj + j, l))), + load_hi(INDEX(A, lda, ii + i, l)), + load_hi(INDEX(B, ldb, jj + j, l)))), + (unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d))); for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -564,19 +594,17 @@ class tinyBLAS_Q0_ARM { } inline int8x16_t load_lo(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), - vdupq_n_u8(0x0f))), + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), - vdupq_n_s8(0x8)); + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA *const A; - const block_q8_0 *const B; - float *const C; + const TB *const B; + TC *const C; const int k; const int lda; const int ldb; @@ -587,14 +615,11 @@ class tinyBLAS_Q0_ARM { #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) -template +template class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX2(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, - int ith, int nth) + 
tinyBLAS_Q0_AVX2(int k, const TA *A, int lda, const TB *B, int ldb, TC *C, int ldc, int ith, + int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -715,8 +740,8 @@ class tinyBLAS_Q0_AVX2 { template NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; + int ytiles = RM > 1 ? (m - m0) / RM : 1; + int xtiles = RN > 1 ? (n - n0) / RN : 1; int tiles = xtiles * ytiles; int duty = (tiles + nth - 1) / nth; int start = duty * ith; @@ -730,16 +755,16 @@ class tinyBLAS_Q0_AVX2 { for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), + Cv[j][i] = madd(_mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d)), + updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), + load(INDEX(A, lda, ii + i, l))), + _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), + load(INDEX(A, lda, ii + i, l)))), Cv[j][i]); for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -763,9 +788,9 @@ class tinyBLAS_Q0_AVX2 { static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); - return _mm256_and_si256(_mm256_set1_epi8(15), - _mm256_insertf128_si256(_mm256_castsi128_si256(x), - _mm_srli_epi16(x, 4), 1)); + return _mm256_and_si256( + _mm256_set1_epi8(15), + _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)); } const TA *const A; @@ -839,21 +864,15 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, #if defined(__AVX512F__) if (k % 16) return false; - tinyBLAS<16, __m512, 
__m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) @@ -861,11 +880,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, return false; if (k % 4) return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -879,11 +895,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) @@ -891,11 +904,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - 
(const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) @@ -905,11 +915,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, return false; if (Btype != GGML_TYPE_F16) return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) @@ -917,11 +924,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -931,21 +935,15 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) - return false; + return false; #if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - 
tinyBLAS_Q0_ARM tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -957,19 +955,13 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, if (Btype != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -997,3 +989,358 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, (void)Btype; (void)Ctype; } + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. 
+// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// MIXTURE OF EXPERTS TENSOR MULTIPLICATION +// +// +// SHAPES +// +// - weights [cols, rows, experts] +// - thought [cols, tasks, tokens] w/ tasks ≤ thinkers +// - result [rows, thinkers, tokens] w/ thinkers ≤ experts +// - plan [thinkers, tokens] w/ i32 < experts +// +// DEFINITION +// +// for thinker in range(thinkers): +// for token in range(tokens): +// for row in range(rows): +// c = 0 +// for col in range(cols): +// expert = plan[token][thinker] +// a = weights[expert][row][col] +// b = thought[token][thinker % tasks][col] +// c += a * b +// result[token][thinker][row] = c +// +// REGULARITIES +// +// - tokens can be odd +// - thinkers is usually 2 +// - tasks is usually 1 or 2 +// - cols should be a multiple of 64 +// - rows should be a multiple of 64 +// - experts is usually 8 but could be 60 +// - tokens is always 1 for token generation +// - tokens can be huge for prompt processing +// +// EXAMPLE +// +// mixtral 8x7b w/ 217 token prompt +// +// | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type +// ========================================================================= +// weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0 +// thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32 +// result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32 +// plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32 +// + +namespace { +class MixMul { + public: + MixMul(const ggml_compute_params *params, const ggml_tensor *weights, + const ggml_tensor *thought, const ggml_tensor *plan, ggml_tensor *result) + : params(params), + weights(weights), + thought(thought), + plan(plan), + result(result), + rows(weights->ne[1]), + cols(weights->ne[0]), + experts(weights->ne[2]), + thinkers(plan->ne[0]), + tasks(thought->ne[1]), + tokens(thought->ne[2]), + ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN), + wdata_((char *)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)), + allocated_(0) { + } + + bool 
allocate_shared_memory() { + if (!(quantized_thought_ = allocate(MATRIX_ALIGN, tokens * tasks * ldq))) + return false; + if (!(rowptr_result_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_thought_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_count_ = allocate(sizeof(int), experts))) + return false; + return true; + } + + size_t get_allocated_bytes() { + return (wdata_ - (char *)params->wdata) + allocated_; + } + + bool mixmul() { + + // invariants + assert(tasks <= thinkers); + assert(thinkers <= experts); + assert(tokens == plan->ne[1]); + assert(rows == result->ne[0]); + assert(cols == thought->ne[0]); + assert(tokens == result->ne[2]); + assert(thinkers == result->ne[1]); + + // dimensionality + assert(plan->ne[2] == 1); + assert(plan->ne[3] == 1); + assert(result->ne[3] == 1); + assert(weights->ne[3] == 1); + assert(thought->ne[3] == 1); + + // miscellaneous + assert(params->nth > 0); + assert(params->ith < params->nth); + assert(plan->type == GGML_TYPE_I32); + + // supported types + if (result->type != GGML_TYPE_F32) + return false; + + // check nb01 is convertible to lda + if (weights->nb[1] % ggml_type_size(weights->type)) + return false; + + // no support for column strides + if (result->nb[0] != ggml_type_size(result->type)) + return false; + if (thought->nb[0] != ggml_type_size(thought->type)) + return false; + if (weights->nb[0] != ggml_type_size(weights->type)) + return false; + + switch (weights->type) { + + case GGML_TYPE_F32: + if (thought->type != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, tinyBLAS, + float, float, float>(); +#elif defined(__AVX__) || defined(__AVX2__) + return mixmat<8, 1, tinyBLAS, float, + float, float>(); +#elif defined(__SSE__) + return mixmat<4, 1, tinyBLAS, float, + float, float>(); +#elif defined(__ARM_NEON) + return mixmat<4, 1, + tinyBLAS, + float, float, float>(); +#else + return false; +#endif + + case 
GGML_TYPE_F16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + return mixmat< + 8, 1, tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + return mixmat< + 8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q4_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q4_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q4_0, block_q8_0, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q8_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q8_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q8_0, block_q8_0, float>(); +#else + return false; +#endif + + default: + return false; + } + } + + private: + template + bool mixmat() { + if (cols % KN) + return false; + switch (params->type) { + case GGML_TASK_TYPE_INIT: + if (thought->type != ggml_type_trait::id) + quantize_thought(ggml_type_trait::id); + build_row_pointers(ggml_type_trait::id); + return true; + case GGML_TASK_TYPE_COMPUTE: + assert(!(cols % BS)); + assert(!(weights->nb[1] % sizeof(TA))); + for (int expert = 0; expert < experts; ++expert) { + BLAS tb{cols / BS, + (const TA *)((const char *)weights->data + expert * weights->nb[2]), + (int)(weights->nb[1] / 
sizeof(TA)), + (const TB *)(rowptr_thought_ + expert * tokens * thinkers), 0, + (TC *)(rowptr_result_ + expert * tokens * thinkers), 0, + params->ith, + params->nth}; + tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE); + } + return true; + default: + return true; + } + } + + void build_row_pointers(ggml_type vec_dot_type) { + for (int expert = params->ith; expert < experts; expert += params->nth) { + int count = 0; + for (int token = 0; token < tokens; ++token) + for (int thinker = 0; thinker < thinkers; ++thinker) + if (expert == *(const int *)((const char *)plan->data + token * plan->nb[1] + + thinker * plan->nb[0])) { + int row = count++; + int idx = expert * thinkers * tokens + row; + rowptr_result_[idx] = + (uintptr_t)((char *)result->data + token * result->nb[2] + + thinker * result->nb[1]); + if (thought->type == vec_dot_type) + rowptr_thought_[idx] = + (uintptr_t)((char *)thought->data + token * thought->nb[2] + + thinker % tasks * thought->nb[1]); + else + rowptr_thought_[idx] = + (uintptr_t)((char *)quantized_thought_ + token * tasks * ldq + + thinker % tasks * ldq); + } + rowptr_count_[expert] = count; + } + } + + void quantize_thought(ggml_type vec_dot_type) { + int chore = 0; + for (int token = 0; token < tokens; ++token) + for (int task = 0; task < tasks; ++task) + if (chore++ % params->nth == params->ith) + quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq, + (const float *)((const char *)thought->data + + token * thought->nb[2] + task * thought->nb[1]), + vec_dot_type); + } + + void quantize_row(void *dst, const float *src, ggml_type type) { + assert((int)ggml_row_size(type, cols) <= ldq); + switch (type) { + case GGML_TYPE_F16: + ggml_fp32_to_fp16_row(src, (ggml_fp16_t *)dst, cols); + break; + case GGML_TYPE_Q8_0: + quantize_row_q8_0((const float *)src, (block_q8_0 *)dst, cols); + break; + default: + GGML_UNREACHABLE(); + } + } + + template + T *allocate(int align, int elems) { + T *res = nullptr; + size_t need = 
sizeof(T) * elems; + size_t base = allocated_; + base += align - 1; + base &= -align; + size_t toto = base + need; + if (toto >= allocated_ && toto <= params->wsize && elems >= 0) { + res = (T *)(wdata_ + base); + allocated_ = toto; + } + return res; + } + + const ggml_compute_params *const params; + const ggml_tensor *const weights; + const ggml_tensor *const thought; + const ggml_tensor *const plan; + ggml_tensor *const result; + const int rows; + const int cols; + const int experts; + const int thinkers; + const int tasks; + const int tokens; + const int ldq; + + // variables + char *const wdata_; + size_t allocated_; + + // shared memory + int *rowptr_count_ /*[experts]*/; + char *quantized_thought_ /*[tokens][tasks][cols][2]*/; + uintptr_t *rowptr_result_ /*[experts][tokens*thinkers]*/; + uintptr_t *rowptr_thought_ /*[experts][tokens*thinkers]*/; +}; +} // namespace + +/** + * Performs "mixture of experts" tensor multiplication on CPU. + */ +bool llamafile_mixmul(const ggml_compute_params *params, + const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan, + ggml_tensor *result) { + MixMul mm{params, weights, thought, plan, result}; + return mm.allocate_shared_memory() && mm.mixmul(); +} + +/** + * Returns number of shared memory bytes llamafile_mixmul() needs. 
+ */ +size_t llamafile_mixmul_needs(const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan) { + ggml_compute_params params{}; + params.wsize = 0x7ffff000; + params.wdata = (void *)0x1000; + MixMul mm{&params, weights, thought, plan, 0}; + if (mm.allocate_shared_memory()) + return mm.get_allocated_bytes(); + else + return 0; +} diff --git a/sgemm.h b/sgemm.h index da23b209c4dd5b..5921cbc9930ef2 100644 --- a/sgemm.h +++ b/sgemm.h @@ -1,12 +1,24 @@ #pragma once +#include <stddef.h> #include <stdbool.h> #ifdef __cplusplus extern "C" { #endif +struct ggml_tensor; +struct ggml_compute_params; + bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, void *, int, int, int, int, int, int, int); +bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *, + const struct ggml_tensor *, const struct ggml_tensor *, + struct ggml_tensor *); + +size_t llamafile_mixmul_needs(const struct ggml_tensor *, + const struct ggml_tensor *, + const struct ggml_tensor *); + #ifdef __cplusplus } #endif