llama : Metal inference #1642

Merged 49 commits on Jun 4, 2023
The file changes shown below are from 7 of the 49 commits.
Commits (49 total, all by ggerganov):
f85020b  mtl : export the LLaMA computation graph (May 29, 2023)
98c267f  ci : disable temporary (May 29, 2023)
b23fe8c  mtl : adapt the MNIST example as starter (May 29, 2023)
a792cbd  mtl : no need for mtl-export tool, add cli arg for main instead (May 29, 2023)
897d6d8  mtl : export just a small part of the graph for now to make it easier (May 29, 2023)
248a8c3  mtl : move MSL code into separate file for easy editing (May 29, 2023)
a8fd9dc  mtl : initial get_rows_q4_0 kernel (May 29, 2023)
794704e  mtl : confirmed get_rows_q4_0 is working correctly (May 30, 2023)
72256eb  mtl : add rms_norm kernel + confirm working (May 30, 2023)
64afc0b  mtl : add mul kernel + confirm working (May 30, 2023)
2a24994  mtl : initial mul_mat Q4 kernel (wrong results) (May 30, 2023)
96d0052  mtl : mul_mat fixes (still wrong) (May 30, 2023)
29bec00  mtl : another mul_mat Q4 (still does not work) (May 30, 2023)
b2fd06c  mtl : working mul_mat q4 (May 30, 2023)
6af6a05  ggml : fix handling of "view" ops in ggml_graph_import() (May 31, 2023)
1213af7  mtl : add rope kernel (May 31, 2023)
7ca81e9  mtl : add reshape and transpose handling (May 31, 2023)
94ea9e7  ggml : store offset as opt arg for ggml_view_xd() operators (Jun 1, 2023)
948fcfd  mtl : add cpy kernel + handle view ops (Jun 1, 2023)
51efb59  mtl : confirm f16 x f32 attention mul mat (Jun 1, 2023)
0f1c580  mtl : add scale kernel (Jun 1, 2023)
17a7036  mtl : add diag_mask_inf kernel (Jun 1, 2023)
17930fb  mtl : fix soft_max kernel (Jun 1, 2023)
f67c2d8  ggml : update ggml_nbytes() to handle non-contiguous tensors (Jun 1, 2023)
a266c26  mtl : verify V tensor contents (Jun 1, 2023)
a0cc3de  mtl : add f32 -> f32 cpy kernel (Jun 1, 2023)
42dca40  mtl : add silu kernel (Jun 1, 2023)
fbd3f62  mtl : add non-broadcast mul kernel (Jun 1, 2023)
9665429  mtl : full GPU inference of the computation graph (Jun 1, 2023)
f0196a7  mtl : optimize rms_norm and soft_max kernels (Jun 1, 2023)
e55f7b0  mtl : add f16 mat x f32 vec multiplication kernel (Jun 1, 2023)
3367146  mtl : fix bug in f16 x f32 mul mat + speed-up computation (Jun 2, 2023)
847bbfe  mtl : faster mul_mat_q4_0_f32 kernel (Jun 2, 2023)
70c3387  mtl : fix kernel signature + roll inner loop (Jun 2, 2023)
b088e14  mtl : more threads for rms_norm + better timing (Jun 2, 2023)
6276057  mtl : remove printfs from inner loop (Jun 2, 2023)
03c2d72  mtl : simplify implementation (Jun 2, 2023)
640a889  mtl : add save/load vocab to ggml file (Jun 2, 2023)
2f4e9d1  mtl : plug Metal inference into llama.cpp (very quick-n-dirty) (Jun 2, 2023)
4df2ef3  mtl : make it work with main example (Jun 3, 2023)
18e482a  mtl : preparing for merge (Jun 4, 2023)
e4b5222  mtl : clean-up ggml mtl interface + suport scratch / inplace (Jun 4, 2023)
e26cd6b  mtl : remove temp / debug code (Jun 4, 2023)
a7fb899  metal : final refactoring and simplification (Jun 4, 2023)
d8a7486  Revert "ci : disable temporary" (Jun 4, 2023)
b252acb  metal : add comments (Jun 4, 2023)
db3db9e  metal : clean-up stuff, fix typos (Jun 4, 2023)
e33002d  readme : add Metal instructions (Jun 4, 2023)
324e823  readme : add example for main (Jun 4, 2023)
17 changes: 0 additions & 17 deletions .github/workflows/editorconfig.yml

This file was deleted.

20 changes: 0 additions & 20 deletions .github/workflows/tidy-post.yml

This file was deleted.

23 changes: 0 additions & 23 deletions .github/workflows/tidy-review.yml

This file was deleted.

1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -37,6 +37,7 @@ else()
    add_subdirectory(save-load-state)
    add_subdirectory(benchmark)
    add_subdirectory(baby-llama)
    add_subdirectory(mtl)
    if(LLAMA_BUILD_SERVER)
        add_subdirectory(server)
    endif()
3 changes: 3 additions & 0 deletions examples/common.cpp
@@ -299,6 +299,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
        params.use_mmap = false;
    } else if (arg == "--mtest") {
        params.mem_test = true;
    } else if (arg == "--export") {
        params.export_cgraph = true;
    } else if (arg == "--verbose-prompt") {
        params.verbose_prompt = true;
    } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -438,6 +440,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    fprintf(stderr, " number of layers to store in VRAM\n");
#endif
    fprintf(stderr, " --mtest compute maximum memory usage\n");
    fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
    fprintf(stderr, " --verbose-prompt print prompt before generation\n");
    fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
    fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
1 change: 1 addition & 0 deletions examples/common.h
@@ -71,6 +71,7 @@ struct gpt_params {
    bool use_mmap = true; // use mmap for faster loads
    bool use_mlock = false; // use mlock to keep model in memory
    bool mem_test = false; // compute maximum memory usage
    bool export_cgraph = false; // export the computation graph
    bool verbose_prompt = false; // print prompt tokens before generation
};

7 changes: 7 additions & 0 deletions examples/main/main.cpp
@@ -134,6 +134,13 @@ int main(int argc, char ** argv) {
        return 0;
    }

    // export the cgraph and exit
    if (params.export_cgraph) {
        llama_eval_export(ctx, "llama.ggml");
        llama_free(ctx);

        return 0;
    }

    std::string path_session = params.path_prompt_cache;
    std::vector<llama_token> session_tokens;
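
For reference, the new flag makes it possible to dump the graph without running any generation. A minimal standalone driver that does the same thing might look like the sketch below (the model path is a placeholder, and llama_init_from_file / llama_eval_export are assumed to be available as on this branch):

```cpp
#include "llama.h"

int main() {
    llama_context_params lparams = llama_context_default_params();

    // placeholder model path - any ggml model file should work here
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
    if (ctx == NULL) {
        return 1;
    }

    // dump the computation graph to 'llama.ggml', same as `main --export`
    llama_eval_export(ctx, "llama.ggml");

    llama_free(ctx);

    return 0;
}
```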
33 changes: 33 additions & 0 deletions examples/mtl/CMakeLists.txt
@@ -0,0 +1,33 @@
if (APPLE)
    #
    # mtl

    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
    find_library(METAL_FRAMEWORK Metal REQUIRED)
    find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

    set(TEST_TARGET mtl)
    add_executable(${TEST_TARGET} mtl.cpp mtl.h mtl.m)
    target_link_libraries(${TEST_TARGET} PRIVATE
        ggml
        ${FOUNDATION_LIBRARY}
        ${METAL_FRAMEWORK}
        ${METALKIT_FRAMEWORK}
        ${METALPERFORMANCE_FRAMEWORK}
    )

    # TODO: temporary until the kernels are ready
    # custom command to build mtl.metal into a library
    # depends on the mtl.metal file
    add_custom_target(mtl.metallib-tmp ALL DEPENDS ${CMAKE_BINARY_DIR}/mtl.metallib)

    add_custom_command(
        OUTPUT ${CMAKE_BINARY_DIR}/mtl.metallib
        COMMAND xcrun -sdk macosx metal -c ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal -o ${CMAKE_BINARY_DIR}/mtl.air
        COMMAND xcrun -sdk macosx metallib ${CMAKE_BINARY_DIR}/mtl.air -o ${CMAKE_BINARY_DIR}/mtl.metallib
        DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal
        COMMENT "Building mtl.metallib"
    )
endif()

51 changes: 51 additions & 0 deletions examples/mtl/mtl.cpp
@@ -0,0 +1,51 @@
#include "ggml.h"
#include "mtl.h"

#include <cstdio>
#include <cstring>
#include <cstdlib>

int main(int argc, char ** argv) {
ggml_time_init();

if (argc != 2) {
fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
return -1;
}

const char * fname_cgraph = argv[1];

// load the compute graph
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;

struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
gf.n_threads = 1;

// allocate work context
static size_t buf_size = gf.work_size; // TODO
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx_work = ggml_init(params);

// this allocates all Metal resources and memory buffers
auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);

// the actual inference happens here
llama_mtl_eval(ctx_mtl, &gf);

llama_mtl_free(ctx_mtl);

ggml_free(ctx_work);
ggml_free(ctx_data);
ggml_free(ctx_eval);

return 0;
}
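
The driver above only runs the graph; it does not read any results back. If one wanted to inspect the output after llama_mtl_eval(), a hypothetical addition at the end of main() could look like the snippet below (it assumes the last node of the imported graph is the f32 output tensor, which this PR does not guarantee):

```cpp
// hypothetical snippet for the end of main() in mtl.cpp above
// assumption: the final graph node holds the f32 output (e.g. the logits)
struct ggml_tensor * out = gf.nodes[gf.n_nodes - 1];

const float   * data = ggml_get_data_f32(out);
const int64_t   n    = ggml_nelements(out);

printf("output has %lld elements, output[0] = %8.4f\n", (long long) n, (double) data[0]);
```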

28 changes: 28 additions & 0 deletions examples/mtl/mtl.h
@@ -0,0 +1,28 @@
#pragma once

struct ggml_context;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

struct ggml_mtl_context;

struct ggml_mtl_context * llama_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph * gf);

void llama_mtl_free(struct ggml_mtl_context * ctx);

// return 0 on success
int llama_mtl_eval(
        struct ggml_mtl_context * ctx,
        struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
