backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels
chaxu01 committed Nov 6, 2024
1 parent b11f9ba commit 639949f
Showing 10 changed files with 870 additions and 91 deletions.
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2047,6 +2047,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_timestamps(common_log_main(), true);
}
).set_env("LLAMA_LOG_TIMESTAMPS"));
add_opt(common_arg(
{"-rtrp", "--runtime-repack"},
string_format("Allow runtime requantization and repacking of Q4_0 to enable optimized GEMM and GEMV kernels (default: %d)", params.runtime_repack),
[](common_params & params) {
params.runtime_repack = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));

return ctx_arg;
}
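With this flag, a CPU-only run can opt in to the repacking path, e.g. something along the lines of `llama-cli -m model-Q4_0.gguf --runtime-repack` (the exact binary name depends on the build; the option is only registered for the main example). As the common.cpp change below shows, enabling it also forces mmap off, because the Q4_0 weights get rewritten in place.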
3 changes: 2 additions & 1 deletion common/common.cpp
@@ -983,7 +983,7 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mmap = params.use_mmap && !params.runtime_repack;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) {
@@ -1053,6 +1053,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.runtime_repack = params.runtime_repack;

if (params.reranking) {
cparams.embeddings = true;
2 changes: 2 additions & 0 deletions common/common.h
@@ -271,6 +271,8 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

bool runtime_repack = false; // runtime repack weight for optimized kernels

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

196 changes: 112 additions & 84 deletions examples/llama-bench/llama-bench.cpp

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion ggml/include/ggml-backend.h
@@ -305,7 +305,19 @@ extern "C" {
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);

// CPU buffer types are always available
//
// CPU backend
//

GGML_API ggml_backend_t ggml_backend_cpu_init(void);

GGML_API bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_API void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack);

// Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);

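A rough sketch of how the new setter fits the CPU backend lifecycle (not from this patch — llama.cpp normally creates the backend internally, so the explicit calls below are only illustrative):

ggml_backend_t backend = ggml_backend_cpu_init();
ggml_backend_cpu_set_n_threads(backend, 8);
ggml_backend_cpu_set_runtime_repack(backend, true); // opt in to Q4_0 repacking for the aarch64 kernels
// ... allocate buffers and compute graphs as usual ...
ggml_backend_free(backend);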
99 changes: 99 additions & 0 deletions ggml/src/ggml-aarch64.c
@@ -3476,3 +3476,102 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}

static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(t->ne[0] % 8 == 0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);

// Do in-place transformation. Allocate scratch buffer
size_t size = sizeof(block_q4_0x4) * t->ne[0] / QK4_0;
if (size > *psize) {
uint8_t *new_mem = realloc(*pmem, size);
if (!new_mem) {
return -1;
}
*pmem = new_mem;
*psize = size;
}
block_q4_0x4 *dst = (block_q4_0x4*) *pmem;
block_q4_0 *src = (block_q4_0*) t->data;
block_q4_0 dst_tmp[4];
int n = t->ne[0];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0;
for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
int cnt = 0;
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
dst[cnt++] = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
}
memcpy(src, dst, size);
src += cnt * 4;
}
return 0;
}

static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(t->ne[0] % 8 == 0);
GGML_ASSERT(interleave_block == 8);

// Do in-place transformation. Allocate scratch buffer
size_t size = sizeof(block_q4_0x8) * t->ne[0] / QK4_0;
if (size > *psize) {
uint8_t *new_mem = realloc(*pmem, size);
if (!new_mem) {
return -1;
}
*pmem = new_mem;
*psize = size;
}
block_q4_0x8 *dst = (block_q4_0x8*) *pmem;
block_q4_0 *src = (block_q4_0*) t->data;
block_q4_0 dst_tmp[8];
int n = t->ne[0];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 8;
int nblocks = t->ne[0] / QK4_0;
for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
int cnt = 0;
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
dst[cnt++] = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
}
memcpy(src, dst, size);
src += cnt * 8; // move past the 8 interleaved rows that were just rewritten in place
}
return 0;
}
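Both helpers size the scratch buffer for one interleaved row group: ne[0]/QK4_0 blocks, where each interleaved block bundles 4 or 8 plain block_q4_0 (the nibbles are also XOR-ed with 0x88, which appears to convert the biased unsigned quants into two's-complement nibbles the optimized kernels can sign-extend directly). A back-of-the-envelope size check, assuming the block definitions from ggml-common.h:

// sketch: scratch bytes needed per interleaved row group
// block_q4_0 is 18 bytes (one fp16 scale + 16 nibble bytes),
// so block_q4_0x4 is 72 bytes and block_q4_0x8 is 144 bytes
static size_t repack_scratch_bytes(int64_t ne0, int nrows_interleaved) {
    const int64_t nblocks_per_row = ne0 / QK4_0;  // e.g. 4096 / 32 = 128
    const size_t  blk_size = (nrows_interleaved == 8) ? sizeof(block_q4_0x8) : sizeof(block_q4_0x4);
    return blk_size * nblocks_per_row;            // e.g. 128 * 72 = 9216 or 128 * 144 = 18432 bytes
}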

// Prepare for optimized kernels if applicable
void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize) {
UNUSED(cur);
UNUSED(pmem);
UNUSED(psize);

#if defined(__ARM_ARCH)
if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
if (repack_q4_0_to_q4_0_8_bl(cur, 8, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_8_8;
}
}
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (repack_q4_0_to_q4_0_4_bl(cur, 8, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_4_8;
}
}
else if (ggml_cpu_has_neon()) {
if (repack_q4_0_to_q4_0_4_bl(cur, 4, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_4_4;
}
}
}
#endif
}
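For orientation, a minimal sketch of how a loader might drive this over a model's tensors — the helper and loop below are illustrative, not part of this commit, and assume the weights live in a ggml_context with writable data (which is why mmap is disabled when the flag is set):

#include <stdlib.h>
#include "ggml.h"
#include "ggml-aarch64.h"

// sketch: repack all eligible Q4_0 weights, reusing one scratch buffer
static void repack_all_weights(struct ggml_context * ctx) {
    uint8_t * scratch      = NULL;
    size_t    scratch_size = 0;
    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
        // no-op unless t is Q4_0 and the CPU has a matching SVE / i8mm / NEON path
        ggml_prepare_optimal_kernel(t, &scratch, &scratch_size);
    }
    free(scratch); // the scratch buffer is only needed during the conversion
}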
2 changes: 2 additions & 0 deletions ggml/src/ggml-aarch64.h
@@ -33,6 +33,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize);

#ifdef __cplusplus
}
#endif