diff --git a/ggml.c b/ggml.c
index 086db96af7fcd1..88fd4823db601e 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11003,11 +11003,14 @@ static void ggml_compute_forward_mul_mat_id(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * ids = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    if (llamafile_mixmul(params, src0, src1, ids, dst))
+        return;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const enum ggml_type type = src0->type;
 
     const bool src1_cont = ggml_is_contiguous(src1);
@@ -18504,6 +18507,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur = 0;
                     const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
+                    const struct ggml_tensor * src2 = node->src[2];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
@@ -18512,6 +18516,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
                     cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    size_t cur2 = llamafile_mixmul_needs(src0, src1, src2);
+                    cur = cur > cur2 ? cur : cur2;
                 } break;
             case GGML_OP_OUT_PROD:
                 {
diff --git a/sgemm.cpp b/sgemm.cpp
index 531e12af361ccf..e65079edf4547e 100644
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -54,6 +54,10 @@
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 
+#define ROW_ALIGN 64
+#define MATRIX_ALIGN 4096
+#define MAX_ALIGN 4096
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -66,14 +70,45 @@
 #define VECTOR_REGISTERS 16
 #endif
 
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
 namespace {
 
 inline float unhalf(ggml_fp16_t d) {
     return GGML_FP16_TO_FP32(d);
 }
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// MATRIX MEMORY INDEXING
+
+#define NCA 1
+#define NCB 2
+#define NCC 4
+
+#define INDEX(A, lda, j, i) index<NC & NC##A>(A, lda, j, i)
+
+template <int NC, typename T>
+inline T &index(T *A, int lda, int j, int i) {
+    if (NC)
+        return ((T **)A)[j][i];
+    else
+        return A[lda * j + i];
+}
+
+template <int NC, typename T>
+inline const T &index(const T *A, int lda, int j, int i) {
+    if (NC)
+        return ((const T *const *)A)[j][i];
+    else
+        return A[lda * j + i];
+}
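+
+// Illustrative sketch (not part of the change itself): with the NCB bit set,
+// the B argument is an array of row pointers rather than a dense matrix, so
+//
+//     float *rows[2] = {r0, r1};                      // gathered rows
+//     float &x = index<NCB>((float *)rows, 0, 1, 3);  // same cell as r1[3]
+//     float &y = index<0>(dense, ld, 1, 3);           // same cell as dense[ld * 1 + 3]
+//
+// which is how the MixMul code added below feeds scattered rows into the
+// shared GEMM kernels.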
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// GGML TYPE TRAITS
+
+template <typename T> struct ggml_type_trait;
+template<> struct ggml_type_trait<float>       { static constexpr ggml_type id = GGML_TYPE_F32; };
+template<> struct ggml_type_trait<ggml_fp16_t> { static constexpr ggml_type id = GGML_TYPE_F16; };
+template<> struct ggml_type_trait<block_q8_0>  { static constexpr ggml_type id = GGML_TYPE_Q8_0; };
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED ARITHMETIC OPERATIONS
 
@@ -240,7 +275,7 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
-template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
+template <int NC, int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
     tinyBLAS(int k,
@@ -410,8 +445,8 @@ class tinyBLAS {
 
     template <int RM, int RN>
     NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
+        int ytiles = RM > 1 ? (m - m0) / RM : 1;
+        int xtiles = RN > 1 ? (n - n0) / RN : 1;
         int tiles = xtiles * ytiles;
         int duty = (tiles + nth - 1) / nth;
         int start = duty * ith;
@@ -425,12 +460,12 @@ class tinyBLAS {
             for (int l = 0; l < k; l += KN)
                 for (int j = 0; j < RN; ++j)
                     for (int i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                        load<V>(B + ldb * (jj + j) + l),
+                        Cv[j][i] = madd(load<V>(&INDEX(A, lda, ii + i, l)),
+                                        load<V>(&INDEX(B, ldb, jj + j, l)),
                                         Cv[j][i]);
             for (int j = 0; j < RN; ++j)
                 for (int i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+                    INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]);
         }
     }
 
@@ -449,7 +484,7 @@ class tinyBLAS {
 // QUANT ZERO MATRIX MULTIPLICATION
 
 #if defined(__ARM_FEATURE_DOTPROD)
-template <typename TA>
+template <int NC, typename TA>
 class tinyBLAS_Q0_ARM {
   public:
     tinyBLAS_Q0_ARM(int k,
@@ -525,8 +560,8 @@ class tinyBLAS_Q0_ARM {
 
     template <int RM, int RN>
     NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
+        int ytiles = RM > 1 ? (m - m0) / RM : 1;
+        int xtiles = RN > 1 ? (n - n0) / RN : 1;
         int tiles = xtiles * ytiles;
         int duty = (tiles + nth - 1) / nth;
         int start = duty * ith;
@@ -543,15 +578,15 @@ class tinyBLAS_Q0_ARM {
                         Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                                vcvtq_f32_s32(vdotq_s32(
                                                    vdotq_s32(vdupq_n_s32(0),
-                                                             load_lo(A + lda * (ii + i) + l),
-                                                             load_lo(B + ldb * (jj + j) + l)),
-                                                   load_hi(A + lda * (ii + i) + l),
-                                                   load_hi(B + ldb * (jj + j) + l))),
-                                               unhalf(A[lda * (ii + i) + l].d) *
-                                               unhalf(B[ldb * (jj + j) + l].d));
+                                                             load_lo(&INDEX(A, lda, ii + i, l)),
+                                                             load_lo(&INDEX(B, ldb, jj + j, l))),
+                                                   load_hi(&INDEX(A, lda, ii + i, l)),
+                                                   load_hi(&INDEX(B, ldb, jj + j, l)))),
+                                               unhalf(INDEX(A, lda, ii + i, l).d) *
+                                               unhalf(INDEX(B, ldb, jj + j, l).d));
             for (int j = 0; j < RN; ++j)
                 for (int i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+                    INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]);
         }
     }
 
@@ -587,7 +622,7 @@ class tinyBLAS_Q0_ARM {
 #endif // __ARM_FEATURE_DOTPROD
 
 #if defined(__AVX2__) || defined(__AVX512F__)
-template <typename TA, typename TB, typename TC>
+template <int NC, typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_AVX2 {
   public:
     tinyBLAS_Q0_AVX2(int k,
@@ -715,8 +750,8 @@ class tinyBLAS_Q0_AVX2 {
 
     template <int RM, int RN>
    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
+        int ytiles = RM > 1 ? (m - m0) / RM : 1;
+        int xtiles = RN > 1 ? (n - n0) / RN : 1;
         int tiles = xtiles * ytiles;
         int duty = (tiles + nth - 1) / nth;
         int start = duty * ith;
@@ -730,16 +765,16 @@ class tinyBLAS_Q0_AVX2 {
             for (int l = 0; l < k; ++l)
                 for (int j = 0; j < RN; ++j)
                     for (int i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
-                                                       unhalf(B[ldb * (jj + j) + l].d)),
-                                        updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                               load(A + lda * (ii + i) + l)),
-                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                               load(A + lda * (ii + i) + l))),
+                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l).d) *
+                                                       unhalf(INDEX(B, ldb, jj + j, l).d)),
+                                        updot(_mm256_sign_epi8(load(&INDEX(A, lda, ii + i, l)),
+                                                               load(&INDEX(A, lda, ii + i, l))),
+                                              _mm256_sign_epi8(load(&INDEX(B, ldb, jj + j, l)),
+                                                               load(&INDEX(A, lda, ii + i, l)))),
                                         Cv[j][i]);
             for (int j = 0; j < RN; ++j)
                 for (int i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+                    INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]);
         }
     }
 
@@ -839,7 +874,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
 #if defined(__AVX512F__)
         if (k % 16)
             return false;
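+        // The new leading template argument selects row-pointer indexing via
+        // INDEX(); passing 0 keeps the original dense, strided layout.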
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<0, 16, __m512, __m512, float, float, float> tb{
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -849,7 +884,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
 #elif defined(__AVX__) || defined(__AVX2__)
         if (k % 8)
             return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<0, 8, __m256, __m256, float, float, float> tb{
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -861,7 +896,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
             return false;
         if (k % 4)
             return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -879,7 +914,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
             return false;
         if (Btype != GGML_TYPE_F32)
             return false;
-        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+        tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, float> tb{
             k, (const ggml_fp16_t *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -891,7 +926,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
             return false;
         if (Btype != GGML_TYPE_F32)
             return false;
-        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+        tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, float> tb{
             k, (const ggml_fp16_t *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -905,7 +940,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
             return false;
         if (Btype != GGML_TYPE_F16)
             return false;
-        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+        tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
             k, (const ggml_fp16_t *)A, lda,
             (const ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
@@ -917,7 +952,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
             return false;
         if (Btype != GGML_TYPE_F32)
             return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+        tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
             k, (const ggml_fp16_t *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -933,7 +968,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
         if (Btype != GGML_TYPE_Q8_0)
             return false;
 #if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+        tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, float> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -941,7 +976,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
         tb.matmul(m, n, task);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q8_0> tb{
+        tinyBLAS_Q0_ARM<0, block_q8_0> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -957,7 +992,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
         if (Btype != GGML_TYPE_Q8_0)
             return false;
 #if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+        tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, float> tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -965,7 +1000,7 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
         tb.matmul(m, n, task);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q4_0> tb{
+        tinyBLAS_Q0_ARM<0, block_q4_0> tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -997,3 +1032,367 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     (void)Btype;
     (void)Ctype;
 }
+
+//
+//                                _   _          ___ _      _   ___
+//                               | |_(_)_ _ _  _| _ ) |    /_\ / __|
+//                               |  _| | ' \ || | _ \ |__ / _ \\__ \.
+//                                \__|_|_||_\_, |___/____/_/ \_\___/
+//                                          |__/
+//
+//                    MIXTURE OF EXPERTS TENSOR MULTIPLICATION
+//
+//
+// SHAPES
+//
+//   - weights [cols, rows, experts]
+//   - thought [cols, tasks, tokens] w/ tasks ≤ thinkers
+//   - result  [rows, thinkers, tokens] w/ thinkers ≤ experts
+//   - plan    [thinkers, tokens] w/ i32 < experts
+//
+// DEFINITION
+//
+//   for thinker in range(thinkers):
+//     for token in range(tokens):
+//       for row in range(rows):
+//         c = 0
+//         for col in range(cols):
+//           expert = plan[token][thinker]
+//           a = weights[expert][row][col]
+//           b = thought[token][thinker % tasks][col]
+//           c += a * b
+//         result[token][thinker][row] = c
+//
+// REGULARITIES
+//
+//   - tokens can be odd
+//   - thinkers is usually 2
+//   - tasks is usually 1 or 2
+//   - cols should be a multiple of 64
+//   - rows should be a multiple of 64
+//   - experts is usually 8 but could be 60
+//   - tokens is always 1 for token generation
+//   - tokens can be huge for prompt processing
+//
+// EXAMPLE
+//
+//   mixtral 8x7b w/ 217 token prompt
+//
+//             |  ne*0  ne*1  ne*2  ne*3 | nb*0  nb*1     nb*2       nb*3       | type
+//     =========================================================================
+//     weights | 16384  6144     8     1 |   18  0x2400   0x3600000  0x1b000000 | q4_0
+//     thought | 16384     2   217     1 |    4  0x10000  0x20000    0x1b20000  | f32
+//     result  |  6144     2   217     1 |    4  0x6000   0xc000     0xa2c000   | f32
+//     plan    |      2   217     1     1 |    4  0x20     0x1b20     0x1b20     | i32
+//
+
+namespace {
+class MixMul {
+  public:
+    MixMul(const ggml_compute_params *params,
+           const ggml_tensor *weights,
+           const ggml_tensor *thought,
+           const ggml_tensor *plan,
+           ggml_tensor *result)
+        : params(params),
+          weights(weights),
+          thought(thought),
+          plan(plan),
+          result(result),
+          rows(weights->ne[1]),
+          cols(weights->ne[0]),
+          experts(weights->ne[2]),
+          thinkers(plan->ne[0]),
+          tasks(thought->ne[1]),
+          tokens(thought->ne[2]),
+          ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN),
+          wdata_((char *)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)),
+          allocated_(0) {
+    }
+
+    bool allocate_shared_memory() {
+        if (!(quantized_thought_ = allocate<char>(MATRIX_ALIGN, tokens * tasks * ldq)))
+            return false;
+        if (!(rowptr_result_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))
+            return false;
+        if (!(rowptr_thought_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))
+            return false;
+        if (!(rowptr_count_ = allocate<int>(sizeof(int), experts)))
+            return false;
+        return true;
+    }
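+
+    // Scratch layout: allocate<T>() below is a bump allocator over
+    // params->wdata, so the shared buffer holds, in order, the quantized copy
+    // of `thought`, the two per-expert row-pointer tables, and the per-expert
+    // row counts; get_allocated_bytes() reports the resulting high-water mark.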
+
+    size_t get_allocated_bytes() {
+        return (wdata_ - (char *)params->wdata) + allocated_;
+    }
+
+    bool mixmul() {
+
+        // invariants
+        assert(tasks <= thinkers);
+        assert(thinkers <= experts);
+        assert(tokens == plan->ne[1]);
+        assert(rows == result->ne[0]);
+        assert(cols == thought->ne[0]);
+        assert(tokens == result->ne[2]);
+        assert(thinkers == result->ne[1]);
+
+        // dimensionality
+        assert(plan->ne[2] == 1);
+        assert(plan->ne[3] == 1);
+        assert(result->ne[3] == 1);
+        assert(weights->ne[3] == 1);
+        assert(thought->ne[3] == 1);
+
+        // miscellaneous
+        assert(params->nth > 0);
+        assert(params->ith < params->nth);
+        assert(plan->type == GGML_TYPE_I32);
+
+        // supported types
+        if (result->type != GGML_TYPE_F32)
+            return false;
+
+        // check nb01 is convertible to lda
+        if (weights->nb[1] % ggml_type_size(weights->type))
+            return false;
+
+        // no support for column strides
+        if (result->nb[0] != ggml_type_size(result->type))
+            return false;
+        if (thought->nb[0] != ggml_type_size(thought->type))
+            return false;
+        if (weights->nb[0] != ggml_type_size(weights->type))
+            return false;
+
+        switch (weights->type) {
+
+        case GGML_TYPE_F32:
+            if (thought->type != GGML_TYPE_F32)
+                return false;
+#if defined(__AVX512F__)
+            return mixmat<16, 1, tinyBLAS<NCB | NCC, 16, __m512, __m512, float, float, float>,
+                          float, float, float>();
+#elif defined(__AVX__) || defined(__AVX2__)
+            return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, __m256, __m256, float, float, float>,
+                          float, float, float>();
+#elif defined(__SSE__)
+            return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, __m128, __m128, float, float, float>,
+                          float, float, float>();
+#elif defined(__ARM_NEON)
+            return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, float, float, float>,
+                          float, float, float>();
+#else
+            return false;
+#endif
+
+        case GGML_TYPE_F16:
+            if (thought->type != GGML_TYPE_F32 &&
+                thought->type != GGML_TYPE_F16)
+                return false;
+#if defined(__AVX512F__)
+            return mixmat<16, 1, tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float>,
+                          ggml_fp16_t, ggml_fp16_t, float>();
+#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+            return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float>,
+                          ggml_fp16_t, ggml_fp16_t, float>();
+#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+            return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float>,
+                          ggml_fp16_t, ggml_fp16_t, float>();
+#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+            return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float>,
+                          ggml_fp16_t, ggml_fp16_t, float>();
+#else
+            return false;
+#endif
+
+        case GGML_TYPE_Q4_0:
+            if (thought->type != GGML_TYPE_F32 &&
+                thought->type != GGML_TYPE_Q8_0)
+                return false;
+#if defined(__AVX2__) || defined(__AVX512F__)
+            return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q4_0, block_q8_0, float>,
+                          block_q4_0, block_q8_0, float>();
+#elif defined(__ARM_FEATURE_DOTPROD)
+            return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q4_0>,
+                          block_q4_0, block_q8_0, float>();
+#else
+            return false;
+#endif
+
+        case GGML_TYPE_Q8_0:
+            if (thought->type != GGML_TYPE_F32 &&
+                thought->type != GGML_TYPE_Q8_0)
+                return false;
+#if defined(__AVX2__) || defined(__AVX512F__)
+            return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q8_0, block_q8_0, float>,
+                          block_q8_0, block_q8_0, float>();
+#elif defined(__ARM_FEATURE_DOTPROD)
+            return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q8_0>,
+                          block_q8_0, block_q8_0, float>();
+#else
+            return false;
+#endif
+
+        default:
+            return false;
+        }
+    }
+
+  private:
+    template <int KN, int BS, typename BLAS, typename TA, typename TB, typename TC>
+    bool mixmat() {
+        if (cols % KN)
+            return false;
+        switch (params->type) {
+        case GGML_TASK_TYPE_INIT:
+            if (thought->type != ggml_type_trait<TB>::id)
+                quantize_thought(ggml_type_trait<TB>::id);
+            build_row_pointers(ggml_type_trait<TB>::id);
+            return true;
+        case GGML_TASK_TYPE_COMPUTE:
+            assert(!(cols % BS));
+            assert(!(weights->nb[1] % sizeof(TA)));
+            for (int expert = 0; expert < experts; ++expert) {
+                BLAS tb{cols / BS,
+                        (const TA *)((const char *)weights->data + expert*weights->nb[2]),
+                        (int)(weights->nb[1] / sizeof(TA)),
+                        (const TB *)(rowptr_thought_ + expert*tokens*thinkers), 0,
+                        (TC *)(rowptr_result_ + expert*tokens*thinkers), 0,
+                        params->ith, params->nth};
+                tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE);
+            }
+            return true;
+        default:
+            return true;
+        }
+    }
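+
+    // build_row_pointers() gathers, for every expert, the result/thought rows
+    // that `plan` assigns to it, so mixmat() above can run each expert as an
+    // independent GEMM with ldb/ldc == 0 and rowptr_count_[expert] in place of n.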
+    void build_row_pointers(ggml_type vec_dot_type) {
+        for (int expert = params->ith; expert < experts; expert += params->nth) {
+            int count = 0;
+            for (int token = 0; token < tokens; ++token)
+                for (int thinker = 0; thinker < thinkers; ++thinker)
+                    if (expert == *(const int *)((const char *)plan->data +
+                                                 token*plan->nb[1] +
+                                                 thinker*plan->nb[0])) {
+                        int row = count++;
+                        int idx = expert*thinkers*tokens + row;
+                        rowptr_result_[idx] = (uintptr_t)((char *)result->data +
+                                                          token*result->nb[2] +
+                                                          thinker*result->nb[1]);
+                        if (thought->type == vec_dot_type)
+                            rowptr_thought_[idx] = (uintptr_t)((char *)thought->data +
+                                                               token*thought->nb[2] +
+                                                               thinker%tasks*thought->nb[1]);
+                        else
+                            rowptr_thought_[idx] = (uintptr_t)((char *)quantized_thought_ +
+                                                               token*tasks*ldq +
+                                                               thinker%tasks*ldq);
+                    }
+            rowptr_count_[expert] = count;
+        }
+    }
+
+    void quantize_thought(ggml_type vec_dot_type) {
+        int chore = 0;
+        for (int token = 0; token < tokens; ++token)
+            for (int task = 0; task < tasks; ++task)
+                if (chore++ % params->nth == params->ith)
+                    quantize_row(quantized_thought_ + token*tasks*ldq + task*ldq,
+                                 (const float *)((const char *)thought->data +
+                                                 token*thought->nb[2] +
+                                                 task*thought->nb[1]),
+                                 vec_dot_type);
+    }
+
+    void quantize_row(void *dst, const float *src, ggml_type type) {
+        assert((int)ggml_row_size(type, cols) <= ldq);
+        switch (type) {
+        case GGML_TYPE_F16:
+            ggml_fp32_to_fp16_row(src, (ggml_fp16_t *)dst, cols);
+            break;
+        case GGML_TYPE_Q8_0:
+            quantize_row_q8_0((const float *)src, (block_q8_0 *)dst, cols);
+            break;
+        default:
+            GGML_UNREACHABLE();
+        }
+    }
+
+    template <typename T>
+    T *allocate(int align, int elems) {
+        T *res = nullptr;
+        size_t need = sizeof(T) * elems;
+        size_t base = allocated_;
+        base += align - 1;
+        base &= -align;
+        size_t toto = base + need;
+        if (toto >= allocated_ && toto <= params->wsize && elems >= 0) {
+            res = (T *)(wdata_ + base);
+            allocated_ = toto;
+        }
+        return res;
+    }
+
+    const ggml_compute_params *const params;
+    const ggml_tensor *const weights;
+    const ggml_tensor *const thought;
+    const ggml_tensor *const plan;
+    ggml_tensor *const result;
+    const int rows;
+    const int cols;
+    const int experts;
+    const int thinkers;
+    const int tasks;
+    const int tokens;
+    const int ldq;
+
+    // variables
+    char *const wdata_;
+    size_t allocated_;
+
+    // shared memory
+    int *rowptr_count_/*[experts]*/;
+    char *quantized_thought_/*[tokens][tasks][cols][2]*/;
+    uintptr_t *rowptr_result_/*[experts][tokens*thinkers]*/;
+    uintptr_t *rowptr_thought_/*[experts][tokens*thinkers]*/;
+};
+} // namespace
+
+/**
+ * Performs "mixture of experts" tensor multiplication on CPU.
+ */
+bool llamafile_mixmul(const ggml_compute_params *params,
+                      const ggml_tensor *weights,
+                      const ggml_tensor *thought,
+                      const ggml_tensor *plan,
+                      ggml_tensor *result) {
+    MixMul mm{params, weights, thought, plan, result};
+    return mm.allocate_shared_memory() &&
+           mm.mixmul();
+}
+
+/**
+ * Returns number of shared memory bytes llamafile_mixmul() needs.
+ */
+size_t llamafile_mixmul_needs(const ggml_tensor *weights,
+                              const ggml_tensor *thought,
+                              const ggml_tensor *plan) {
+    ggml_compute_params params{};
+    params.wsize = 0x7ffff000;
+    params.wdata = (void *)0x1000;
+    MixMul mm{&params, weights, thought, plan, 0};
+    if (mm.allocate_shared_memory())
+        return mm.get_allocated_bytes();
+    else
+        return 0;
+}
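+
+// Typical call pattern, mirroring the ggml.c side of this patch:
+//
+//     size_t need = llamafile_mixmul_needs(src0, src1, src2);   // while sizing wsize
+//     ...
+//     if (llamafile_mixmul(params, src0, src1, ids, dst))       // in mul_mat_id
+//         return;                                               // otherwise fall through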
diff --git a/sgemm.h b/sgemm.h
index da23b209c4dd5b..5921cbc9930ef2 100644
--- a/sgemm.h
+++ b/sgemm.h
@@ -1,12 +1,24 @@
 #pragma once
+#include <stddef.h>
 #include <stdbool.h>
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+struct ggml_tensor;
+struct ggml_compute_params;
+
 bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, void *, int,
                      int, int, int, int, int, int);
 
+bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *,
+                      const struct ggml_tensor *, const struct ggml_tensor *,
+                      struct ggml_tensor *);
+
+size_t llamafile_mixmul_needs(const struct ggml_tensor *,
+                              const struct ggml_tensor *,
+                              const struct ggml_tensor *);
+
 #ifdef __cplusplus
 }
 #endif